Make the VariantDevice type visible at the same level of CustomDevice.
In practice this is already done by expanding its absl::variant<> definition in a handful of places. By making the type visible we can properly account for its usage. PiperOrigin-RevId: 305760610 Change-Id: I95d65461ebb70c2d4e33eb59985b01d6cb18554e
This commit is contained in:
parent
7d5d21afd1
commit
747c37add5
tensorflow/core
common_runtime/eager
context.hcore.cceager_operation.cceager_operation.hexecute.ccexecute_node.cctensor_handle.cctensor_handle.h
distributed_runtime/eager
@ -125,6 +125,11 @@ class CustomDevice {
|
||||
int* num_retvals) = 0;
|
||||
};
|
||||
|
||||
// Custom devices do many of the same things as physical Devices, but have a
|
||||
// much more restricted interface. We pass around ambiguous pointers since
|
||||
// TensorHandles may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
public:
|
||||
static const uint64 kInvalidContextId = 0;
|
||||
|
@ -21,8 +21,7 @@ limitations under the License.
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsCPU(
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
||||
bool IsCPU(tensorflow::VariantDevice variant) {
|
||||
if (VariantDeviceIsCustom(variant)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ void EagerOperation::Clear() {
|
||||
}
|
||||
|
||||
const string& EagerOperation::DeviceName() const {
|
||||
absl::variant<tensorflow::Device*, CustomDevice*> variant_device =
|
||||
VariantDevice variant_device =
|
||||
(Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device();
|
||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||
variant_device);
|
||||
|
@ -119,9 +119,7 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
|
||||
// Like TensorHandles, EagerOperations may be placed either on a virtual
|
||||
// CustomDevice or on a physical Device.
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> Device() const {
|
||||
return device_;
|
||||
}
|
||||
VariantDevice Device() const { return device_; }
|
||||
|
||||
void SetDevice(tensorflow::Device* device) {
|
||||
device_ = device;
|
||||
@ -185,7 +183,7 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
AttrBuilder attrs_;
|
||||
const AttrTypeMap* attr_types_;
|
||||
absl::InlinedVector<TensorHandle*, 4> inputs_;
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> device_;
|
||||
VariantDevice device_;
|
||||
string raw_device_name_;
|
||||
string device_name_;
|
||||
DeviceNameUtils::ParsedName device_parsed_name_;
|
||||
|
@ -80,8 +80,7 @@ const string& DeviceNameOrUnspecified(Device* device) {
|
||||
return (device == nullptr) ? *unspecified_string : device->name();
|
||||
}
|
||||
|
||||
const string& DeviceNameOrUnspecified(
|
||||
absl::variant<Device*, CustomDevice*> device) {
|
||||
const string& DeviceNameOrUnspecified(VariantDevice device) {
|
||||
if (VariantDeviceIsCustom(device)) {
|
||||
return absl::get<CustomDevice*>(device)->name();
|
||||
} else {
|
||||
|
@ -53,8 +53,7 @@ Status ExecuteNodeArgs::Init(
|
||||
serialize_remote_handle_ =
|
||||
[ctx, &op_inputs](const int i,
|
||||
eager::RemoteTensorHandle* handle) -> Status {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
op_inputs[i]->device();
|
||||
VariantDevice variant_device = op_inputs[i]->device();
|
||||
if (VariantDeviceIsCustom(variant_device)) {
|
||||
return errors::Internal(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
|
@ -326,8 +326,7 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
|
||||
return mirror.TensorValue(t);
|
||||
}
|
||||
|
||||
TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
|
||||
const EagerContext& ctx) const {
|
||||
VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
|
||||
if (VariantDeviceIsCustom(device_)) {
|
||||
return device_;
|
||||
} else {
|
||||
@ -788,16 +787,15 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
|
||||
return status;
|
||||
}
|
||||
|
||||
bool VariantDeviceIsCustom(
|
||||
absl::variant<Device*, CustomDevice*> variant_device) {
|
||||
bool VariantDeviceIsCustom(VariantDevice variant_device) {
|
||||
return variant_device.index() != 0;
|
||||
}
|
||||
|
||||
string VariantDeviceName(absl::variant<Device*, CustomDevice*> device) {
|
||||
string VariantDeviceName(VariantDevice device) {
|
||||
return absl::visit([](auto* device) { return device->name(); }, device);
|
||||
}
|
||||
|
||||
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device) {
|
||||
string VariantDeviceDebugString(VariantDevice device) {
|
||||
if (device == kVariantDeviceNull) {
|
||||
return "[]";
|
||||
} else if (VariantDeviceIsCustom(device)) {
|
||||
|
@ -55,11 +55,6 @@ class EagerContext;
|
||||
// (unrelated to python TensorHandle).
|
||||
class TensorHandle : public AbstractTensorHandleInterface,
|
||||
public core::RefCounted {
|
||||
// Custom devices do many of the same things as physical Devices, but have a
|
||||
// much more restricted interface. We pass around ambiguous pointers since
|
||||
// TensorHandles may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
// TensorHandle for dtype != DT_RESOURCE
|
||||
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
Device* resource_device, EagerContext* ctx);
|
||||
@ -291,18 +286,17 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
};
|
||||
|
||||
// Checks whether a VariantDevice contains a custom device.
|
||||
bool VariantDeviceIsCustom(absl::variant<Device*, CustomDevice*> device);
|
||||
bool VariantDeviceIsCustom(VariantDevice device);
|
||||
|
||||
// Wraps device->name() or CustomDevice->name().
|
||||
string VariantDeviceName(absl::variant<Device*, CustomDevice*> device);
|
||||
string VariantDeviceName(VariantDevice device);
|
||||
|
||||
// Wraps device->DebugString() or CustomDevice->name().
|
||||
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device);
|
||||
string VariantDeviceDebugString(VariantDevice device);
|
||||
|
||||
// Indicates either HostCPU or an unset physical device. We never set a null
|
||||
// CustomDevice*.
|
||||
const absl::variant<Device*, CustomDevice*> kVariantDeviceNull =
|
||||
static_cast<Device*>(nullptr);
|
||||
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
|
||||
|
||||
// Returns the device backing the resource. Else, returns nullptr.
|
||||
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
|
||||
|
@ -76,7 +76,7 @@ Status RemoteMgr::GetMirroredResourceShape(
|
||||
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
|
||||
int64* op_id, int32* output_num) {
|
||||
// TODO(allenl): Consider supporting remote handles on custom devices.
|
||||
absl::variant<Device*, CustomDevice*> device = handle->device();
|
||||
VariantDevice device = handle->device();
|
||||
if (VariantDeviceIsCustom(device)) {
|
||||
return errors::Unimplemented(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
|
Loading…
Reference in New Issue
Block a user