From 747c37add59d0dd2b96f1f552baf80e63f70ff1c Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Thu, 9 Apr 2020 14:28:01 -0700 Subject: [PATCH] Make the VariantDevice type visible at the same level of CustomDevice. In practice this is already done by expanding its absl::variant<> definition in a handful of places. By making the type visible we can properly account for its usage. PiperOrigin-RevId: 305760610 Change-Id: I95d65461ebb70c2d4e33eb59985b01d6cb18554e --- tensorflow/core/common_runtime/eager/context.h | 5 +++++ tensorflow/core/common_runtime/eager/core.cc | 3 +-- .../core/common_runtime/eager/eager_operation.cc | 2 +- .../core/common_runtime/eager/eager_operation.h | 6 ++---- tensorflow/core/common_runtime/eager/execute.cc | 3 +-- .../core/common_runtime/eager/execute_node.cc | 3 +-- .../core/common_runtime/eager/tensor_handle.cc | 10 ++++------ .../core/common_runtime/eager/tensor_handle.h | 14 ++++---------- .../core/distributed_runtime/eager/remote_mgr.cc | 2 +- 9 files changed, 20 insertions(+), 28 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 877d8072008..1670345efd5 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -125,6 +125,11 @@ class CustomDevice { int* num_retvals) = 0; }; +// Custom devices do many of the same things as physical Devices, but have a +// much more restricted interface. We pass around ambiguous pointers since +// TensorHandles may be placed either on custom or physical devices. +using VariantDevice = absl::variant; + class EagerContext : public AbstractContextInterface, public core::RefCounted { public: static const uint64 kInvalidContextId = 0; diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index de7e7475a1c..cfb188bdd77 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -21,8 +21,7 @@ limitations under the License. namespace { -bool IsCPU( - absl::variant variant) { +bool IsCPU(tensorflow::VariantDevice variant) { if (VariantDeviceIsCustom(variant)) { return false; } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 3804f5164d4..d3a31278326 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -37,7 +37,7 @@ void EagerOperation::Clear() { } const string& EagerOperation::DeviceName() const { - absl::variant variant_device = + VariantDevice variant_device = (Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device(); return absl::visit([](auto* d) -> const string& { return d->name(); }, variant_device); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 550881c571b..d1128977ace 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -119,9 +119,7 @@ class EagerOperation : public AbstractOperationInterface { // Like TensorHandles, EagerOperations may be placed either on a virtual // CustomDevice or on a physical Device. - absl::variant Device() const { - return device_; - } + VariantDevice Device() const { return device_; } void SetDevice(tensorflow::Device* device) { device_ = device; @@ -185,7 +183,7 @@ class EagerOperation : public AbstractOperationInterface { AttrBuilder attrs_; const AttrTypeMap* attr_types_; absl::InlinedVector inputs_; - absl::variant device_; + VariantDevice device_; string raw_device_name_; string device_name_; DeviceNameUtils::ParsedName device_parsed_name_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 10c9c5ef54f..8c602b0f498 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -80,8 +80,7 @@ const string& DeviceNameOrUnspecified(Device* device) { return (device == nullptr) ? *unspecified_string : device->name(); } -const string& DeviceNameOrUnspecified( - absl::variant device) { +const string& DeviceNameOrUnspecified(VariantDevice device) { if (VariantDeviceIsCustom(device)) { return absl::get(device)->name(); } else { diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index 5ced006fb9e..f2528081877 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -53,8 +53,7 @@ Status ExecuteNodeArgs::Init( serialize_remote_handle_ = [ctx, &op_inputs](const int i, eager::RemoteTensorHandle* handle) -> Status { - absl::variant variant_device = - op_inputs[i]->device(); + VariantDevice variant_device = op_inputs[i]->device(); if (VariantDeviceIsCustom(variant_device)) { return errors::Internal( "Custom devices and remote execution are currently not supported " diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 2cbb978b5ee..858d0a338ae 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -326,8 +326,7 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { return mirror.TensorValue(t); } -TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU( - const EagerContext& ctx) const { +VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const { if (VariantDeviceIsCustom(device_)) { return device_; } else { @@ -788,16 +787,15 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx, return status; } -bool VariantDeviceIsCustom( - absl::variant variant_device) { +bool VariantDeviceIsCustom(VariantDevice variant_device) { return variant_device.index() != 0; } -string VariantDeviceName(absl::variant device) { +string VariantDeviceName(VariantDevice device) { return absl::visit([](auto* device) { return device->name(); }, device); } -string VariantDeviceDebugString(absl::variant device) { +string VariantDeviceDebugString(VariantDevice device) { if (device == kVariantDeviceNull) { return "[]"; } else if (VariantDeviceIsCustom(device)) { diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 9309b4fcccd..0b39161af73 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -55,11 +55,6 @@ class EagerContext; // (unrelated to python TensorHandle). class TensorHandle : public AbstractTensorHandleInterface, public core::RefCounted { - // Custom devices do many of the same things as physical Devices, but have a - // much more restricted interface. We pass around ambiguous pointers since - // TensorHandles may be placed either on custom or physical devices. - using VariantDevice = absl::variant; - // TensorHandle for dtype != DT_RESOURCE TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, Device* resource_device, EagerContext* ctx); @@ -291,18 +286,17 @@ class TensorHandle : public AbstractTensorHandleInterface, }; // Checks whether a VariantDevice contains a custom device. -bool VariantDeviceIsCustom(absl::variant device); +bool VariantDeviceIsCustom(VariantDevice device); // Wraps device->name() or CustomDevice->name(). -string VariantDeviceName(absl::variant device); +string VariantDeviceName(VariantDevice device); // Wraps device->DebugString() or CustomDevice->name(). -string VariantDeviceDebugString(absl::variant device); +string VariantDeviceDebugString(VariantDevice device); // Indicates either HostCPU or an unset physical device. We never set a null // CustomDevice*. -const absl::variant kVariantDeviceNull = - static_cast(nullptr); +const VariantDevice kVariantDeviceNull = static_cast(nullptr); // Returns the device backing the resource. Else, returns nullptr. Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index ef3d42de037..c120a28032c 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -76,7 +76,7 @@ Status RemoteMgr::GetMirroredResourceShape( Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, int64* op_id, int32* output_num) { // TODO(allenl): Consider supporting remote handles on custom devices. - absl::variant device = handle->device(); + VariantDevice device = handle->device(); if (VariantDeviceIsCustom(device)) { return errors::Unimplemented( "Custom devices and remote execution are currently not supported "