Make the VariantDevice type visible at the same level of CustomDevice.

In practice this is already done by expanding its absl::variant<>
definition in a handful of places. By making the type visible we can
properly account for its usage.

PiperOrigin-RevId: 305760610
Change-Id: I95d65461ebb70c2d4e33eb59985b01d6cb18554e
This commit is contained in:
Cesar Crusius 2020-04-09 14:28:01 -07:00 committed by TensorFlower Gardener
parent 7d5d21afd1
commit 747c37add5
9 changed files with 20 additions and 28 deletions

View File

@ -125,6 +125,11 @@ class CustomDevice {
int* num_retvals) = 0;
};
// Custom devices do many of the same things as physical Devices, but have a
// much more restricted interface. We pass around ambiguous pointers since
// TensorHandles may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;
class EagerContext : public AbstractContextInterface, public core::RefCounted {
public:
static const uint64 kInvalidContextId = 0;

View File

@ -21,8 +21,7 @@ limitations under the License.
namespace {
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
bool IsCPU(tensorflow::VariantDevice variant) {
if (VariantDeviceIsCustom(variant)) {
return false;
}

View File

@ -37,7 +37,7 @@ void EagerOperation::Clear() {
}
const string& EagerOperation::DeviceName() const {
absl::variant<tensorflow::Device*, CustomDevice*> variant_device =
VariantDevice variant_device =
(Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);

View File

@ -119,9 +119,7 @@ class EagerOperation : public AbstractOperationInterface {
// Like TensorHandles, EagerOperations may be placed either on a virtual
// CustomDevice or on a physical Device.
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> Device() const {
return device_;
}
VariantDevice Device() const { return device_; }
void SetDevice(tensorflow::Device* device) {
device_ = device;
@ -185,7 +183,7 @@ class EagerOperation : public AbstractOperationInterface {
AttrBuilder attrs_;
const AttrTypeMap* attr_types_;
absl::InlinedVector<TensorHandle*, 4> inputs_;
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> device_;
VariantDevice device_;
string raw_device_name_;
string device_name_;
DeviceNameUtils::ParsedName device_parsed_name_;

View File

@ -80,8 +80,7 @@ const string& DeviceNameOrUnspecified(Device* device) {
return (device == nullptr) ? *unspecified_string : device->name();
}
const string& DeviceNameOrUnspecified(
absl::variant<Device*, CustomDevice*> device) {
const string& DeviceNameOrUnspecified(VariantDevice device) {
if (VariantDeviceIsCustom(device)) {
return absl::get<CustomDevice*>(device)->name();
} else {

View File

@ -53,8 +53,7 @@ Status ExecuteNodeArgs::Init(
serialize_remote_handle_ =
[ctx, &op_inputs](const int i,
eager::RemoteTensorHandle* handle) -> Status {
absl::variant<Device*, CustomDevice*> variant_device =
op_inputs[i]->device();
VariantDevice variant_device = op_inputs[i]->device();
if (VariantDeviceIsCustom(variant_device)) {
return errors::Internal(
"Custom devices and remote execution are currently not supported "

View File

@ -326,8 +326,7 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
return mirror.TensorValue(t);
}
TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
const EagerContext& ctx) const {
VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
if (VariantDeviceIsCustom(device_)) {
return device_;
} else {
@ -788,16 +787,15 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
return status;
}
bool VariantDeviceIsCustom(
absl::variant<Device*, CustomDevice*> variant_device) {
bool VariantDeviceIsCustom(VariantDevice variant_device) {
return variant_device.index() != 0;
}
string VariantDeviceName(absl::variant<Device*, CustomDevice*> device) {
string VariantDeviceName(VariantDevice device) {
return absl::visit([](auto* device) { return device->name(); }, device);
}
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device) {
string VariantDeviceDebugString(VariantDevice device) {
if (device == kVariantDeviceNull) {
return "[]";
} else if (VariantDeviceIsCustom(device)) {

View File

@ -55,11 +55,6 @@ class EagerContext;
// (unrelated to python TensorHandle).
class TensorHandle : public AbstractTensorHandleInterface,
public core::RefCounted {
// Custom devices do many of the same things as physical Devices, but have a
// much more restricted interface. We pass around ambiguous pointers since
// TensorHandles may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;
// TensorHandle for dtype != DT_RESOURCE
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
Device* resource_device, EagerContext* ctx);
@ -291,18 +286,17 @@ class TensorHandle : public AbstractTensorHandleInterface,
};
// Checks whether a VariantDevice contains a custom device.
bool VariantDeviceIsCustom(absl::variant<Device*, CustomDevice*> device);
bool VariantDeviceIsCustom(VariantDevice device);
// Wraps device->name() or CustomDevice->name().
string VariantDeviceName(absl::variant<Device*, CustomDevice*> device);
string VariantDeviceName(VariantDevice device);
// Wraps device->DebugString() or CustomDevice->name().
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device);
string VariantDeviceDebugString(VariantDevice device);
// Indicates either HostCPU or an unset physical device. We never set a null
// CustomDevice*.
const absl::variant<Device*, CustomDevice*> kVariantDeviceNull =
static_cast<Device*>(nullptr);
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
// Returns the device backing the resource. Else, returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

View File

@ -76,7 +76,7 @@ Status RemoteMgr::GetMirroredResourceShape(
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
int64* op_id, int32* output_num) {
// TODO(allenl): Consider supporting remote handles on custom devices.
absl::variant<Device*, CustomDevice*> device = handle->device();
VariantDevice device = handle->device();
if (VariantDeviceIsCustom(device)) {
return errors::Unimplemented(
"Custom devices and remote execution are currently not supported "