Use the same CompositeDevice name on remote workers as the one on a client.

PiperOrigin-RevId: 317702206
Change-Id: I7068efb25eb930252f89a167108ed59c69c2078f
This commit is contained in:
Yujing Zhang 2020-06-22 11:48:57 -07:00 committed by TensorFlower Gardener
parent 8cf9784629
commit 39d080e8b9
13 changed files with 99 additions and 23 deletions

View File

@ -25,6 +25,16 @@ const char* const kCompositeDeviceType = "COMPOSITE";
// Creates a CompositeDevice whose name is derived from `host_name`: the device
// type is replaced with kCompositeDeviceType and the device id with
// `unique_device_id`. The actual construction is delegated to the overload
// that takes a fully-formed device name.
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
    const std::vector<string>& underlying_devices, const int unique_device_id,
    const DeviceNameUtils::ParsedName& host_name, Status* status) {
  // Start from the host's parsed name so the remaining name components are
  // carried over unchanged.
  DeviceNameUtils::ParsedName composite_parsed_name(host_name);
  composite_parsed_name.id = unique_device_id;
  composite_parsed_name.type = kCompositeDeviceType;
  return CompositeDevice::MakeDevice(
      underlying_devices,
      DeviceNameUtils::ParsedNameToString(composite_parsed_name), status);
}
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
const std::vector<string>& underlying_devices, const string& device_name,
Status* status) {
if (underlying_devices.empty()) {
status->Update(
errors::InvalidArgument("underlying_devices should not be empty."));
@ -63,13 +73,8 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
}
}
DeviceNameUtils::ParsedName parsed_composite_name = host_name;
DeviceAttributes device_attributes;
parsed_composite_name.type = kCompositeDeviceType;
parsed_composite_name.id = unique_device_id;
const string composite_name =
DeviceNameUtils::ParsedNameToString(parsed_composite_name);
device_attributes.set_name(composite_name);
device_attributes.set_name(device_name);
device_attributes.set_device_type(kCompositeDeviceType);
return absl::WrapUnique(

View File

@ -48,6 +48,11 @@ class CompositeDevice : public Device {
const std::vector<string>& underlying_devices, const int unique_device_id,
const DeviceNameUtils::ParsedName& host_name, Status* status);
// Helper for creating a CompositeDevice with the given device name.
// `device_name` is set as the name of the created device as-is. Failures
// (e.g. an empty `underlying_devices`) are reported through `status`.
static std::unique_ptr<CompositeDevice> MakeDevice(
const std::vector<string>& underlying_devices, const string& device_name,
Status* status);
private:
CompositeDevice(const DeviceAttributes& device_attributes,
const std::vector<string>& underlying_devices)

View File

@ -80,4 +80,20 @@ TEST(CompositeDeviceTest, Basic) {
}
}
// Checks that a CompositeDevice created from an explicit device name reports
// exactly that name, the composite device type, and the underlying devices it
// was built from.
TEST(CompositeDeviceTest, DeviceName) {
  const std::vector<string> underlying_devices = {
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:0/device:CPU:1"};
  const string composite_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:10";

  Status status;
  const std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice(underlying_devices, composite_device_name,
                                  &status);
  TF_ASSERT_OK(status);

  // The given name is used for the device as-is.
  EXPECT_EQ(composite_device->name(), composite_device_name);
  EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
  EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
}
} // namespace tensorflow

View File

@ -893,7 +893,7 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
}
Status EagerContext::FindCompositeDeviceFromName(
const char* device_name, CompositeDevice** device) const {
StringPiece device_name, CompositeDevice** device) const {
tf_shared_lock l(composite_devices_mu_);
for (const auto& d : composite_devices_) {
if (d.second->name() == device_name) {
@ -939,8 +939,13 @@ Status EagerContext::RegisterCustomDevice(
}
Status EagerContext::FindOrCreateCompositeDevice(
const std::vector<string>& underlying_devices,
const std::vector<string>& underlying_devices, const string& device_name,
CompositeDevice** composite_device) {
if (!device_name.empty() &&
FindCompositeDeviceFromName(device_name, composite_device).ok()) {
return Status::OK();
}
const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
mutex_lock l(composite_devices_mu_);
@ -951,11 +956,16 @@ Status EagerContext::FindOrCreateCompositeDevice(
}
Status s;
// Create a CompositeDevice on the same task as the host CPU, in order to
// trigger packed TensorHandle copy from a client to a remote worker.
auto device =
CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(),
HostCPU()->parsed_name(), &s);
std::unique_ptr<CompositeDevice> device;
if (device_name.empty()) {
// Create a CompositeDevice on the same task as the host CPU, in order to
// trigger packed TensorHandle copy from a client to a remote worker.
device = CompositeDevice::MakeDevice(underlying_devices,
composite_devices_.size(),
HostCPU()->parsed_name(), &s);
} else {
device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
}
TF_RETURN_IF_ERROR(s);
*composite_device = device.get();
pflr_->AddCompositeDevice(*composite_device);

View File

@ -486,7 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindDeviceFromName(const char* device_name, Device** device) const;
Status FindCompositeDeviceFromName(const char* device_name,
Status FindCompositeDeviceFromName(StringPiece device_name,
CompositeDevice** device) const;
Status FindCustomDeviceFromName(const string& device_name,
@ -495,9 +495,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device);
// Find or create a composite device with the given `underlying_devices`.
// Find or create a composite device with the given `underlying_devices` and
// `device_name` (if not empty).
Status FindOrCreateCompositeDevice(
const std::vector<string>& underlying_devices,
const std::vector<string>& underlying_devices, const string& device_name,
CompositeDevice** composite_device);
bool OnSameTask(const Device* first, const Device* second) const;

View File

@ -177,6 +177,7 @@ TEST_F(EagerContextTest, CompositeDevice) {
"/job:worker/replica:0/task:0/device:CPU:1"};
CompositeDevice* composite_device_0 = nullptr;
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
/*device_name=*/"",
&composite_device_0));
EXPECT_EQ(composite_device_0->name(),
"/job:localhost/replica:0/task:0/device:COMPOSITE:0");
@ -186,11 +187,13 @@ TEST_F(EagerContextTest, CompositeDevice) {
EXPECT_EQ(device, composite_device_0);
CompositeDevice* composite_device_1 = nullptr;
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
/*device_name=*/"",
&composite_device_1));
EXPECT_EQ(composite_device_1, composite_device_0);
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
CompositeDevice* composite_device_2 = nullptr;
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
/*device_name=*/"",
&composite_device_2));
EXPECT_EQ(composite_device_2->name(),
"/job:localhost/replica:0/task:0/device:COMPOSITE:1");
@ -202,5 +205,33 @@ TEST_F(EagerContextTest, CompositeDevice) {
"/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
}
// Exercises FindOrCreateCompositeDevice with an explicit device name: the first
// call creates a device carrying that name, and a later call with the same name
// returns the existing device even though different underlying devices are
// passed in.
TEST_F(EagerContextTest, CompositeDeviceWithGivenName) {
  InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);

  const string composite_device_name =
      "/job:worker1/replica:0/task:0/device:COMPOSITE:5";
  const std::vector<string> underlying_devices_0 = {
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:0/device:CPU:1"};

  // Step 1: create a CompositeDevice with the given name.
  CompositeDevice* created_device = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
      underlying_devices_0, composite_device_name, &created_device));
  EXPECT_EQ(created_device->name(), composite_device_name);

  // Step 2: the new device can be looked up by its name.
  CompositeDevice* looked_up_device = nullptr;
  TF_EXPECT_OK(context()->FindCompositeDeviceFromName(composite_device_name,
                                                      &looked_up_device));
  EXPECT_EQ(looked_up_device, created_device);

  // Step 3: asking for the same name again finds the existing device rather
  // than creating a new one.
  std::vector<string> underlying_devices_1 = {
      "/job:worker/replica:0/task:0/device:CPU:1",
      "/job:worker/replica:0/task:0/device:CPU:2"};
  CompositeDevice* found_device = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
      underlying_devices_1, created_device->name(), &found_device));
  EXPECT_EQ(found_device, created_device);
}
} // namespace
} // namespace tensorflow

View File

@ -415,7 +415,7 @@ Status GetOrCreateKernelAndDevice(
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
input_dev_ptrs.push_back(input_device);
CompositeDevice* composite_device = nullptr;
if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
if (ctx.FindCompositeDeviceFromName(input_device->name(),
&composite_device)
.ok()) {
composite_devices[input_device->name()] =

View File

@ -320,6 +320,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape,
const string& device_name,
EagerContext* ctx,
TensorHandle** packed_handle) {
if (handles.empty()) {
@ -343,8 +344,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
}
CompositeDevice* composite_device = nullptr;
TF_RETURN_IF_ERROR(
ctx->FindOrCreateCompositeDevice(devices, &composite_device));
TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(devices, device_name,
&composite_device));
*packed_handle =
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
@ -363,8 +364,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
tensorflow::DataType dtype = handles.at(0)->dtype;
tensorflow::TensorShape shape;
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
packed_handle);
return CreatePackedHandle(std::move(handles), dtype, shape,
/*device_name*/ "", ctx, packed_handle);
}
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,

View File

@ -94,7 +94,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape,
EagerContext* ctx,
const string& device_name, EagerContext* ctx,
TensorHandle** packed_handle);
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
EagerContext* ctx,

View File

@ -685,7 +685,7 @@ Status EagerServiceImpl::SendPackedHandle(
// Create a unshaped packed TensorHandle.
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
eager_context, &packed_handle));
send_packed_handle.device_name(), eager_context, &packed_handle));
for (auto* h : handles) {
// Unref handle since it has a ref in the packed handle now.

View File

@ -1071,6 +1071,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
const string composite_device =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
uint64 context_id = random::New64();
CreateContextRequest request;
@ -1125,6 +1127,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
EXPECT_EQ(absl::get<Device*>(packed_handle->device())->name(),
composite_device);
TensorHandle* handle0 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));

View File

@ -298,6 +298,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
const Device* target_device, EagerContext* ctx,
SendPackedHandleOp* op) {
op->set_op_id(op_id);
op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx)));
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
TensorHandle* h = nullptr;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));

View File

@ -258,6 +258,8 @@ message SendPackedHandleOp {
}
repeated Handle handles = 2;
string device_name = 3;
}
////////////////////////////////////////////////////////////////////////////////