Add functions that return device type and ID for eager.

This addition enables more efficient device handling in S4TF without needing to parse the full device string. As support is added for devices beyond TF eager, this information is needed more often and has a larger impact on performance.

Partial fix for https://github.com/tensorflow/swift/issues/524.

PiperOrigin-RevId: 337696655
Change-Id: Ifb576d37c765cced2329b77e0cebb591d8d3a46c
Michelle Casbon, 2020-10-17 18:35:58 -07:00 (committed by TensorFlower Gardener)
parent 0f9acc15ba
commit 9f2b92b4e9
7 changed files with 215 additions and 0 deletions
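
For context, a hedged sketch of what this changes for callers; `h` is assumed to be a valid TFE_TensorHandle* and `status` a TF_Status*:

    // Before: recover placement by parsing the full device name, e.g.
    // "/job:localhost/replica:0/task:0/device:GPU:0" -> "GPU", 0.
    const char* name = TFE_TensorHandleDeviceName(h, status);

    // After: query the type and ID directly.
    const char* type = TFE_TensorHandleDeviceType(h, status);  // "GPU"
    int id = TFE_TensorHandleDeviceID(h, status);              // 0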


@@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                       TF_Status* status) {
   tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
 }
+
+const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
+  if (h == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
+    return nullptr;
+  }
+  return tensorflow::unwrap(h)->DeviceType(&status->status);
+}
+
+int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
+  if (h == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
+    return -1;
+  }
+  return tensorflow::unwrap(h)->DeviceId(&status->status);
+}


@@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
                                                      unsigned char enable,
                                                      TF_Status* status);
+
+// Returns the device type of the operation that produced `h`.
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
+    TFE_TensorHandle* h, TF_Status* status);
+
+// Returns the device ID of the operation that produced `h`.
+TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
+                                                   TF_Status* status);
 
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
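
A minimal end-to-end sketch of the declarations above, not a definitive example: it assumes a build that links the TF C API, a handle placed on the default CPU device, and abbreviates error handling.

    #include <stdio.h>
    #include <string.h>
    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/eager/c_api.h"

    int main() {
      TF_Status* status = TF_NewStatus();
      TFE_ContextOptions* opts = TFE_NewContextOptions();
      TFE_Context* ctx = TFE_NewContext(opts, status);
      TFE_DeleteContextOptions(opts);

      // Wrap a scalar float in an eager tensor handle.
      float value = 1.0f;
      TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr,
                                       /*num_dims=*/0, sizeof(float));
      memcpy(TF_TensorData(t), &value, sizeof(float));
      TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);

      // Query placement without parsing the full device string.
      const char* device_type = TFE_TensorHandleDeviceType(h, status);
      int device_id = TFE_TensorHandleDeviceID(h, status);
      if (TF_GetCode(status) == TF_OK) {
        printf("%s:%d\n", device_type, device_id);  // expected: CPU:0
      }

      TFE_DeleteTensorHandle(h);
      TF_DeleteTensor(t);
      TFE_DeleteContext(ctx);
      TF_DeleteStatus(status);
      return 0;
    }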


@@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
   TF_DeleteStatus(status);
 }
+
+TEST(CAPI, TensorHandleNullptr) {
+  TFE_TensorHandle* h = nullptr;
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+  ASSERT_EQ(device_type, nullptr);
+  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
+
+  TF_SetStatus(status.get(), TF_OK, "");
+
+  int device_id = TFE_TensorHandleDeviceID(h, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+  ASSERT_EQ(device_id, -1);
+  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
+}
+
+TEST(CAPI, TensorHandleDevices) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
+  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
+  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id) << device_id;
+
+  // Disable the test if no GPU is present.
+  string gpu_device_name;
+  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
+
+    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_EQ(0, device_id) << device_id;
+
+    TFE_DeleteOp(shape_op);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
+  TFE_DeleteContext(ctx);
+}
+
+TEST(CAPI, TensorHandleDefaults) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
+  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
+  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id) << device_id;
+
+  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
+      h_default, ctx, "/device:CPU:0", status.get());
+  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
+  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
+
+  TFE_DeleteTensorHandle(h_default);
+  TFE_DeleteTensorHandle(h_cpu);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
+  TFE_DeleteContext(ctx);
+}
+
 } // namespace
 } // namespace tensorflow
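
Assuming a standard TensorFlow source checkout, a command along these lines should exercise the new tests (target and filter are assumptions based on the file's location):

    bazel test //tensorflow/c/eager:c_api_test --test_filter=CAPI.TensorHandleDevices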


@@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
   virtual const char* DeviceName(Status* status) const = 0;
   // Returns the device where the tensor was placed.
   virtual const char* BackingDeviceName(Status* status) const = 0;
+  // Returns the device type which created the handle.
+  virtual const char* DeviceType(Status* status) const = 0;
+  // Returns the device ID which created the handle.
+  virtual int DeviceId(Status* status) const = 0;
 
   // Returns a tensor for the handle. If tensor is remote, it will be copied.
   virtual AbstractTensorInterface* Resolve(Status* status) = 0;
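
A self-contained mock (not TensorFlow code) illustrating the contract the two new pure virtuals add for every concrete handle implementation:

    #include <iostream>

    struct Status {};  // stand-in for tensorflow::Status, for illustration only

    class HandleInterface {
     public:
      virtual ~HandleInterface() = default;
      // Every concrete handle must now report its device type and ID
      // directly, alongside the existing name accessors.
      virtual const char* DeviceType(Status* status) const = 0;
      virtual int DeviceId(Status* status) const = 0;
    };

    // A trivial implementation pinned to CPU:0.
    class CpuHandle : public HandleInterface {
     public:
      const char* DeviceType(Status* status) const override { return "CPU"; }
      int DeviceId(Status* status) const override { return 0; }
    };

    int main() {
      Status s;
      CpuHandle h;
      std::cout << h.DeviceType(&s) << ":" << h.DeviceId(&s) << "\n";  // CPU:0
      return 0;
    }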


@@ -1116,6 +1116,28 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
   }
 }
 
+const char* TensorHandle::DeviceType(Status* status) const {
+  if (VariantDeviceIsCustom(device())) {
+    status->Update(
+        tensorflow::errors::Unimplemented("Custom device unsupported"));
+    return nullptr;
+  }
+  status->Update(WaitUnknownDevice());
+  tensorflow::Device* d = op_device();
+  return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str();
+}
+
+int TensorHandle::DeviceId(Status* status) const {
+  if (VariantDeviceIsCustom(device())) {
+    status->Update(
+        tensorflow::errors::Unimplemented("Custom device unsupported"));
+    return -1;
+  }
+  status->Update(WaitUnknownDevice());
+  tensorflow::Device* d = op_device();
+  return (d == nullptr) ? 0 : d->parsed_name().id;
+}
+
 tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
   Ref();
   return this;
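
The returned values come from the structured parse of the device name: `op_device()->parsed_name()` holds the same fields that DeviceNameUtils extracts from a full device string, and the nullptr fallback ("CPU" / 0) appears to cover handles whose op device is unset, i.e. tensors created on the default CPU device. A hedged standalone sketch of the parsing step:

    #include <iostream>
    #include "tensorflow/core/util/device_name_utils.h"

    int main() {
      tensorflow::DeviceNameUtils::ParsedName parsed;
      // ParseFullName fills in type ("GPU") and id (1) from the full string.
      if (tensorflow::DeviceNameUtils::ParseFullName(
              "/job:localhost/replica:0/task:0/device:GPU:1", &parsed)) {
        std::cout << parsed.type << ":" << parsed.id << std::endl;  // GPU:1
      }
      return 0;
    }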


@@ -131,6 +131,8 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   const char* DeviceName(Status* status) const override;
   const char* BackingDeviceName(Status* status) const override;
+  const char* DeviceType(Status* status) const override;
+  int DeviceId(Status* status) const override;
 
   AbstractTensorInterface* Resolve(Status* status) override;
   ImmediateExecutionTensorHandle* Copy() override;


@@ -408,4 +408,63 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
   context->Unref();
 }
+
+TEST(TensorHandle_DeviceNameTest, OnLocalDevice) {
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
+  devices.emplace_back(
+      CreateDevice("GPU", "/job:localhost/replica:0/task:0/device:GPU:0"));
+  StaticDeviceMgr local_device_mgr(std::move(devices));
+  auto ctx = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &local_device_mgr, false, nullptr, nullptr);
+
+  Device* dcpu = local_device_mgr.ListDevices()[0];
+  Device* dgpu = local_device_mgr.ListDevices()[1];
+  tensorflow::DataType dtype = DT_RESOURCE;
+  TensorShape shape = {2};
+  Tensor tcpu(dtype, shape);
+  Tensor tgpu(dtype, shape);
+  Status s;
+
+  TensorHandle* th_cpu =
+      TensorHandle::CreateLocalHandle(std::move(tcpu), dcpu, dcpu, dcpu, ctx);
+  const char* device_name = th_cpu->DeviceName(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_TRUE(absl::StrContains(device_name, "CPU")) << device_name;
+  const char* backing_device_name = th_cpu->BackingDeviceName(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU"))
+      << backing_device_name;
+  const char* device_type = th_cpu->DeviceType(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
+  int device_id = th_cpu->DeviceId(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_EQ(0, device_id) << device_id;
+
+  TensorHandle* th_gpu =
+      TensorHandle::CreateLocalHandle(std::move(tgpu), dgpu, dgpu, dgpu, ctx);
+  device_name = th_gpu->DeviceName(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_TRUE(absl::StrContains(device_name, "GPU")) << device_name;
+  backing_device_name = th_gpu->BackingDeviceName(&s);
+  TF_EXPECT_OK(s);
+  std::cout << "backing_device_name for GPU: " << backing_device_name
+            << std::endl;
+  ASSERT_TRUE(absl::StrContains(backing_device_name, "GPU"))
+      << backing_device_name;
+  device_type = th_gpu->DeviceType(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
+  device_id = th_gpu->DeviceId(&s);
+  TF_EXPECT_OK(s);
+  ASSERT_EQ(0, device_id) << device_id;
+
+  th_cpu->Unref();
+  th_gpu->Unref();
+  ctx->Unref();
+}
+
 } // namespace tensorflow