Use the same CompositeDevice name on remote workers as the one on a client.
PiperOrigin-RevId: 317702206 Change-Id: I7068efb25eb930252f89a167108ed59c69c2078f
This commit is contained in:
parent
8cf9784629
commit
39d080e8b9
@ -25,6 +25,16 @@ const char* const kCompositeDeviceType = "COMPOSITE";
|
||||
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||
const DeviceNameUtils::ParsedName& host_name, Status* status) {
|
||||
DeviceNameUtils::ParsedName parsed_name = host_name;
|
||||
parsed_name.type = kCompositeDeviceType;
|
||||
parsed_name.id = unique_device_id;
|
||||
const string device_name = DeviceNameUtils::ParsedNameToString(parsed_name);
|
||||
return CompositeDevice::MakeDevice(underlying_devices, device_name, status);
|
||||
}
|
||||
|
||||
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||
const std::vector<string>& underlying_devices, const string& device_name,
|
||||
Status* status) {
|
||||
if (underlying_devices.empty()) {
|
||||
status->Update(
|
||||
errors::InvalidArgument("underlying_devices should not be empty."));
|
||||
@ -63,13 +73,8 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||
}
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName parsed_composite_name = host_name;
|
||||
DeviceAttributes device_attributes;
|
||||
parsed_composite_name.type = kCompositeDeviceType;
|
||||
parsed_composite_name.id = unique_device_id;
|
||||
const string composite_name =
|
||||
DeviceNameUtils::ParsedNameToString(parsed_composite_name);
|
||||
device_attributes.set_name(composite_name);
|
||||
device_attributes.set_name(device_name);
|
||||
device_attributes.set_device_type(kCompositeDeviceType);
|
||||
|
||||
return absl::WrapUnique(
|
||||
|
@ -48,6 +48,11 @@ class CompositeDevice : public Device {
|
||||
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||
const DeviceNameUtils::ParsedName& host_name, Status* status);
|
||||
|
||||
// Helper for creating a CompositeDevice with the given device name.
|
||||
static std::unique_ptr<CompositeDevice> MakeDevice(
|
||||
const std::vector<string>& underlying_devices, const string& device_name,
|
||||
Status* status);
|
||||
|
||||
private:
|
||||
CompositeDevice(const DeviceAttributes& device_attributes,
|
||||
const std::vector<string>& underlying_devices)
|
||||
|
@ -80,4 +80,20 @@ TEST(CompositeDeviceTest, Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the name-based MakeDevice overload uses the caller-supplied
// device name as-is, reports kCompositeDeviceType as the device type, and
// keeps the list of underlying devices unchanged.
TEST(CompositeDeviceTest, DeviceName) {
  // NOTE(review): the name below carries a "CPU" device type while the device
  // type is expected to be COMPOSITE; presumably intentional to show the name
  // is not re-derived - confirm.
  const string composite_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:10";
  const std::vector<string> underlying_devices = {
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:0/device:CPU:1"};

  Status status;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice(underlying_devices, composite_device_name,
                                  &status);
  // Abort the test on failure so the device is not dereferenced below.
  TF_ASSERT_OK(status);

  EXPECT_EQ(composite_device->name(), composite_device_name);
  EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
  EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -893,7 +893,7 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
|
||||
}
|
||||
|
||||
Status EagerContext::FindCompositeDeviceFromName(
|
||||
const char* device_name, CompositeDevice** device) const {
|
||||
StringPiece device_name, CompositeDevice** device) const {
|
||||
tf_shared_lock l(composite_devices_mu_);
|
||||
for (const auto& d : composite_devices_) {
|
||||
if (d.second->name() == device_name) {
|
||||
@ -939,8 +939,13 @@ Status EagerContext::RegisterCustomDevice(
|
||||
}
|
||||
|
||||
Status EagerContext::FindOrCreateCompositeDevice(
|
||||
const std::vector<string>& underlying_devices,
|
||||
const std::vector<string>& underlying_devices, const string& device_name,
|
||||
CompositeDevice** composite_device) {
|
||||
if (!device_name.empty() &&
|
||||
FindCompositeDeviceFromName(device_name, composite_device).ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
|
||||
|
||||
mutex_lock l(composite_devices_mu_);
|
||||
@ -951,11 +956,16 @@ Status EagerContext::FindOrCreateCompositeDevice(
|
||||
}
|
||||
|
||||
Status s;
|
||||
// Create a CompositeDevice on the same task as the host CPU, in order to
|
||||
// trigger packed TensorHandle copy from a client to a remote worker.
|
||||
auto device =
|
||||
CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(),
|
||||
HostCPU()->parsed_name(), &s);
|
||||
std::unique_ptr<CompositeDevice> device;
|
||||
if (device_name.empty()) {
|
||||
// Create a CompositeDevice on the same task as the host CPU, in order to
|
||||
// trigger packed TensorHandle copy from a client to a remote worker.
|
||||
device = CompositeDevice::MakeDevice(underlying_devices,
|
||||
composite_devices_.size(),
|
||||
HostCPU()->parsed_name(), &s);
|
||||
} else {
|
||||
device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
*composite_device = device.get();
|
||||
pflr_->AddCompositeDevice(*composite_device);
|
||||
|
@ -486,7 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
|
||||
Status FindDeviceFromName(const char* device_name, Device** device) const;
|
||||
|
||||
Status FindCompositeDeviceFromName(const char* device_name,
|
||||
Status FindCompositeDeviceFromName(StringPiece device_name,
|
||||
CompositeDevice** device) const;
|
||||
|
||||
Status FindCustomDeviceFromName(const string& device_name,
|
||||
@ -495,9 +495,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
Status RegisterCustomDevice(const string& name,
|
||||
std::unique_ptr<CustomDevice> device);
|
||||
|
||||
// Find or create a composite device with the given `underlying_devices`.
|
||||
// Find or create a composite device with the given `underlying_devices` and
|
||||
// `device_name` (if not empty).
|
||||
Status FindOrCreateCompositeDevice(
|
||||
const std::vector<string>& underlying_devices,
|
||||
const std::vector<string>& underlying_devices, const string& device_name,
|
||||
CompositeDevice** composite_device);
|
||||
|
||||
bool OnSameTask(const Device* first, const Device* second) const;
|
||||
|
@ -177,6 +177,7 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
||||
"/job:worker/replica:0/task:0/device:CPU:1"};
|
||||
CompositeDevice* composite_device_0 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
/*device_name=*/"",
|
||||
&composite_device_0));
|
||||
EXPECT_EQ(composite_device_0->name(),
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0");
|
||||
@ -186,11 +187,13 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
||||
EXPECT_EQ(device, composite_device_0);
|
||||
CompositeDevice* composite_device_1 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
/*device_name=*/"",
|
||||
&composite_device_1));
|
||||
EXPECT_EQ(composite_device_1, composite_device_0);
|
||||
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
|
||||
CompositeDevice* composite_device_2 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
/*device_name=*/"",
|
||||
&composite_device_2));
|
||||
EXPECT_EQ(composite_device_2->name(),
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:1");
|
||||
@ -202,5 +205,33 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
|
||||
}
|
||||
|
||||
// Exercises FindOrCreateCompositeDevice with an explicit device name:
//  1. the first call creates a device carrying exactly that name;
//  2. the device can then be found by name;
//  3. a second call with the same name but a different underlying device list
//     returns the already-registered device instead of creating a new one.
TEST_F(EagerContextTest, CompositeDeviceWithGivenName) {
  InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
  const string composite_device_name =
      "/job:worker1/replica:0/task:0/device:COMPOSITE:5";

  // Create a CompositeDevice with the given name.
  const std::vector<string> first_devices = {
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:0/device:CPU:1"};
  CompositeDevice* created = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
      first_devices, composite_device_name, &created));
  EXPECT_EQ(created->name(), composite_device_name);

  // The new device must be discoverable through a lookup by name.
  CompositeDevice* looked_up = nullptr;
  TF_EXPECT_OK(
      context()->FindCompositeDeviceFromName(composite_device_name, &looked_up));
  EXPECT_EQ(looked_up, created);

  // Find a CompositeDevice with the given name: the name takes precedence, so
  // the differing device list still resolves to the existing device.
  const std::vector<string> second_devices = {
      "/job:worker/replica:0/task:0/device:CPU:1",
      "/job:worker/replica:0/task:0/device:CPU:2"};
  CompositeDevice* reused = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
      second_devices, created->name(), &reused));
  EXPECT_EQ(reused, created);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -415,7 +415,7 @@ Status GetOrCreateKernelAndDevice(
|
||||
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
|
||||
input_dev_ptrs.push_back(input_device);
|
||||
CompositeDevice* composite_device = nullptr;
|
||||
if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
|
||||
if (ctx.FindCompositeDeviceFromName(input_device->name(),
|
||||
&composite_device)
|
||||
.ok()) {
|
||||
composite_devices[input_device->name()] =
|
||||
|
@ -320,6 +320,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
|
||||
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
const tensorflow::DataType dtype,
|
||||
const tensorflow::TensorShape& shape,
|
||||
const string& device_name,
|
||||
EagerContext* ctx,
|
||||
TensorHandle** packed_handle) {
|
||||
if (handles.empty()) {
|
||||
@ -343,8 +344,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
}
|
||||
|
||||
CompositeDevice* composite_device = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->FindOrCreateCompositeDevice(devices, &composite_device));
|
||||
TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(devices, device_name,
|
||||
&composite_device));
|
||||
*packed_handle =
|
||||
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
|
||||
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
|
||||
@ -363,8 +364,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
tensorflow::DataType dtype = handles.at(0)->dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
|
||||
return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
|
||||
packed_handle);
|
||||
return CreatePackedHandle(std::move(handles), dtype, shape,
|
||||
/*device_name*/ "", ctx, packed_handle);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
||||
|
@ -94,7 +94,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
|
||||
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
const tensorflow::DataType dtype,
|
||||
const tensorflow::TensorShape& shape,
|
||||
EagerContext* ctx,
|
||||
const string& device_name, EagerContext* ctx,
|
||||
TensorHandle** packed_handle);
|
||||
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
EagerContext* ctx,
|
||||
|
@ -685,7 +685,7 @@ Status EagerServiceImpl::SendPackedHandle(
|
||||
// Create a unshaped packed TensorHandle.
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||
std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
|
||||
eager_context, &packed_handle));
|
||||
send_packed_handle.device_name(), eager_context, &packed_handle));
|
||||
|
||||
for (auto* h : handles) {
|
||||
// Unref handle since it has a ref in the packed handle now.
|
||||
|
@ -1071,6 +1071,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
|
||||
const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
const string composite_device =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
|
||||
uint64 context_id = random::New64();
|
||||
CreateContextRequest request;
|
||||
@ -1125,6 +1127,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
|
||||
|
||||
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
|
||||
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
|
||||
EXPECT_EQ(absl::get<Device*>(packed_handle->device())->name(),
|
||||
composite_device);
|
||||
|
||||
TensorHandle* handle0 = nullptr;
|
||||
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
|
||||
|
@ -298,6 +298,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
|
||||
const Device* target_device, EagerContext* ctx,
|
||||
SendPackedHandleOp* op) {
|
||||
op->set_op_id(op_id);
|
||||
op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx)));
|
||||
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
|
||||
TensorHandle* h = nullptr;
|
||||
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
|
||||
|
@ -258,6 +258,8 @@ message SendPackedHandleOp {
|
||||
}
|
||||
|
||||
repeated Handle handles = 2;
|
||||
|
||||
string device_name = 3;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
x
Reference in New Issue
Block a user