diff --git a/tensorflow/core/common_runtime/composite_device.h b/tensorflow/core/common_runtime/composite_device.h index c68c395198a..364fae1b118 100644 --- a/tensorflow/core/common_runtime/composite_device.h +++ b/tensorflow/core/common_runtime/composite_device.h @@ -53,6 +53,8 @@ class CompositeDevice : public Device { const std::vector<string>& underlying_devices, const string& device_name, Status* status); + bool IsRemoteCallAllowed() const override { return false; } + private: CompositeDevice(const DeviceAttributes& device_attributes, const std::vector<string>& underlying_devices) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 40c31185eac..50f3b52e4c6 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -200,13 +200,15 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( // "TPU_SYSTEM" indicates that `device` is a CPU. return Status::OK(); } - if (device_type == "GPU" || device_type == "TPU") { + + if (device->IsRemoteCallAllowed()) { auto* dev_info = flr->device()->tensorflow_gpu_device_info(); if (dev_info) { *device_context = dev_info->default_context; return Status::OK(); } } + return errors::Internal("Device type: ", device_type, " is currently unsupported for remote ", "function executions"); diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 9a7c730c1fb..36792bd5a33 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -149,6 +149,10 @@ class RenamedDevice : public Device { bool IsLocal() const override { return underlying_device_->IsLocal(); } + bool IsRemoteCallAllowed() const override { + return underlying_device_->IsRemoteCallAllowed(); + } + private: RenamedDevice(Device* underlying, const DeviceAttributes& attributes, bool owns_underlying, bool isolate_session_state, diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index bb9b074858a..dd3a0ec3521 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -50,6 +50,8 @@ class RemoteDevice : public Device { bool IsLocal() const override { return false; } + bool IsRemoteCallAllowed() const override { return true; } + private: const string local_dev_name_; diff --git a/tensorflow/core/framework/device.cc b/tensorflow/core/framework/device.cc index 50453822230..eedb08c39c3 100644 --- a/tensorflow/core/framework/device.cc +++ b/tensorflow/core/framework/device.cc @@ -54,4 +54,21 @@ DeviceAttributes Device::BuildDeviceAttributes( return da; } +bool Device::IsRemoteCallAllowed() const { + auto& type = parsed_name_.type; + if (type == "TPU") { + return true; + } + if (type == "TPU_SYSTEM") { + return true; + } + if (type == "CPU") { + return true; + } + if (type == "GPU") { + return true; + } + return false; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h index 0f544bdd123..bdd671779fd 100644 --- a/tensorflow/core/framework/device.h +++ b/tensorflow/core/framework/device.h @@ -178,6 +178,9 @@ class Device : public DeviceBase { virtual bool IsLocal() const { return true; } + // Informs if this Device can be used as a caller in RemoteCall operation. + virtual bool IsRemoteCallAllowed() const; + protected: void DeleteResourceMgr() { delete rmgr_;