From 35d8bef702aa8d4b1e6b37f2b7e3493e4a4f18c3 Mon Sep 17 00:00:00 2001 From: Pawel Piskorski Date: Mon, 31 Aug 2020 19:12:57 +0200 Subject: [PATCH] replace PFLR DeviceGetContext hardcode with Device::IsRemoteCallAllowed This delegates decision of whether a particular device type can host a remote call to the device implementation itself. Change-Id: I189e9d965ade7386b0c24da6c90d369b1ed72c3e --- .../core/common_runtime/composite_device.h | 2 ++ .../process_function_library_runtime.cc | 4 +++- .../core/common_runtime/renamed_device.h | 2 ++ .../core/distributed_runtime/remote_device.cc | 2 ++ tensorflow/core/framework/device.cc | 18 ++++++++++++++++++ tensorflow/core/framework/device.h | 3 +++ 6 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/composite_device.h b/tensorflow/core/common_runtime/composite_device.h index c68c395198a..364fae1b118 100644 --- a/tensorflow/core/common_runtime/composite_device.h +++ b/tensorflow/core/common_runtime/composite_device.h @@ -53,6 +53,8 @@ class CompositeDevice : public Device { const std::vector& underlying_devices, const string& device_name, Status* status); + bool IsRemoteCallAllowed() const override { return false; } + private: CompositeDevice(const DeviceAttributes& device_attributes, const std::vector& underlying_devices) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 40c31185eac..50f3b52e4c6 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -200,13 +200,15 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( // "TPU_SYSTEM" indicates that `device` is a CPU. return Status::OK(); } - if (device_type == "GPU" || device_type == "TPU") { + + if (device->IsRemoteCallAllowed()) { auto* dev_info = flr->device()->tensorflow_gpu_device_info(); if (dev_info) { *device_context = dev_info->default_context; return Status::OK(); } } + return errors::Internal("Device type: ", device_type, " is currently unsupported for remote ", "function executions"); diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 9a7c730c1fb..56350dacbc7 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -149,6 +149,8 @@ class RenamedDevice : public Device { bool IsLocal() const override { return underlying_device_->IsLocal(); } + bool IsRemoteCallAllowed() const override { return underlying_device_->IsRemoteCallAllowed(); } + private: RenamedDevice(Device* underlying, const DeviceAttributes& attributes, bool owns_underlying, bool isolate_session_state, diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index bb9b074858a..dd3a0ec3521 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -50,6 +50,8 @@ class RemoteDevice : public Device { bool IsLocal() const override { return false; } + bool IsRemoteCallAllowed() const override { return true; } + private: const string local_dev_name_; diff --git a/tensorflow/core/framework/device.cc b/tensorflow/core/framework/device.cc index 50453822230..1b79a35e3c7 100644 --- a/tensorflow/core/framework/device.cc +++ b/tensorflow/core/framework/device.cc @@ -54,4 +54,22 @@ DeviceAttributes Device::BuildDeviceAttributes( return da; } + +bool Device::IsRemoteCallAllowed() const { + auto &type = parsed_name_.type; + if (type == "TPU") { + return true; + } + if (type == "TPU_SYSTEM") { + return true; + } + if (type == "CPU") { + return true; + } + if (type == "GPU") { + return true; + } + return false; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h index 0f544bdd123..bdd671779fd 100644 --- a/tensorflow/core/framework/device.h +++ b/tensorflow/core/framework/device.h @@ -178,6 +178,9 @@ class Device : public DeviceBase { virtual bool IsLocal() const { return true; } + // Informs if this Device can be used as a caller in RemoteCall operation. + virtual bool IsRemoteCallAllowed() const; + protected: void DeleteResourceMgr() { delete rmgr_;