Add functions that return device type and ID for eager.
This addition enables more efficient device handling in S4TF without needing to parse the full device string. As support for devices beyond TF eager is added, this info is needed more often and has a bigger impact on performance. Partial fix for https://github.com/tensorflow/swift/issues/524. PiperOrigin-RevId: 337696655 Change-Id: Ifb576d37c765cced2329b77e0cebb591d8d3a46c
This commit is contained in:
parent
0f9acc15ba
commit
9f2b92b4e9
tensorflow
@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
|||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the device type of the operation that produced `h` (e.g. "CPU",
// "GPU"). On a null handle, sets an InvalidArgument status and returns
// nullptr.
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
  // Fast path: a valid handle delegates straight to the unwrapped handle.
  if (h != nullptr) {
    return tensorflow::unwrap(h)->DeviceType(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return nullptr;
}
|
||||||
|
|
||||||
|
// Returns the device ID of the operation that produced `h`. On a null
// handle, sets an InvalidArgument status and returns -1.
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
  // Fast path: a valid handle delegates straight to the unwrapped handle.
  if (h != nullptr) {
    return tensorflow::unwrap(h)->DeviceId(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return -1;
}
|
||||||
|
@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
|||||||
unsigned char enable,
|
unsigned char enable,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Returns the device type of the operation that produced `h`.
|
||||||
|
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
|
||||||
|
TFE_TensorHandle* h, TF_Status* status);
|
||||||
|
|
||||||
|
// Returns the device ID of the operation that produced `h`.
|
||||||
|
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
|||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the error contract of the device-query C APIs on a null handle:
// both must report InvalidArgument with the "Invalid handle" message, and
// return their documented sentinel (nullptr / -1).
TEST(CAPI, TensorHandleNullptr) {
  TFE_TensorHandle* null_handle = nullptr;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Device type query on a null handle.
  const char* device_type = TFE_TensorHandleDeviceType(null_handle, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_type, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  // Reset the status so the second query is checked independently.
  TF_SetStatus(status.get(), TF_OK, "");

  // Device ID query on a null handle.
  int device_id = TFE_TensorHandleDeviceID(null_handle, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_id, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
|
||||||
|
|
||||||
|
// Verifies device type/ID queries on real handles: a host tensor reports
// CPU, and (when a GPU is available) the output of an op placed on the GPU
// reports GPU. Device ID 0 is expected in both cases since only the first
// device of each type is used.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // A freshly created host tensor should report a CPU placement with ID 0.
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // Run an op explicitly placed on the GPU; its output handle should
    // report a GPU placement (the queries reflect the producing op's
    // device, not where the result tensor lives).
    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;

    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_EQ(0, device_id) << device_id;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  // Drain pending async work before tearing the context down.
  TFE_DeleteTensorHandle(hcpu);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
|
||||||
|
|
||||||
|
// Verifies that handles created without an explicit placement, and handles
// explicitly copied to "/device:CPU:0", both report device type CPU and
// device ID 0.
TEST(CAPI, TensorHandleDefaults) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Default placement: host CPU, device 0.
  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Explicit copy to the CPU must report the same placement.
  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
      h_default, ctx, "/device:CPU:0", status.get());
  // Check the copy succeeded before dereferencing h_cpu; otherwise a failed
  // copy would surface below as a misleading "Invalid handle" error.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;

  // Drain pending async work before tearing the context down.
  TFE_DeleteTensorHandle(h_default);
  TFE_DeleteTensorHandle(h_cpu);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
|||||||
virtual const char* DeviceName(Status* status) const = 0;
|
virtual const char* DeviceName(Status* status) const = 0;
|
||||||
// Returns the device where the tensor was placed.
|
// Returns the device where the tensor was placed.
|
||||||
virtual const char* BackingDeviceName(Status* status) const = 0;
|
virtual const char* BackingDeviceName(Status* status) const = 0;
|
||||||
|
// Returns the device type which created the handle.
|
||||||
|
virtual const char* DeviceType(Status* status) const = 0;
|
||||||
|
// Returns the device ID which created the handle.
|
||||||
|
virtual int DeviceId(Status* status) const = 0;
|
||||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||||
|
|
||||||
|
@ -1116,6 +1116,28 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the type string ("CPU", "GPU", ...) of the device that created
// this handle, or nullptr (with an Unimplemented status) for custom devices.
const char* TensorHandle::DeviceType(Status* status) const {
  // Custom devices do not expose a parsed device name to report from.
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return nullptr;
  }
  // Block until an as-yet-unknown device has been resolved.
  status->Update(WaitUnknownDevice());
  tensorflow::Device* placement = op_device();
  if (placement == nullptr) {
    // A null op device denotes the default (host CPU) placement.
    return "CPU";
  }
  return placement->parsed_name().type.c_str();
}
|
||||||
|
|
||||||
|
// Returns the ID of the device that created this handle, or -1 (with an
// Unimplemented status) for custom devices.
int TensorHandle::DeviceId(Status* status) const {
  // Custom devices do not expose a parsed device name to report from.
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return -1;
  }
  // Block until an as-yet-unknown device has been resolved.
  status->Update(WaitUnknownDevice());
  tensorflow::Device* placement = op_device();
  if (placement == nullptr) {
    // A null op device denotes the default (host CPU) placement, ID 0.
    return 0;
  }
  return placement->parsed_name().id;
}
|
||||||
|
|
||||||
tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
|
tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
|
||||||
Ref();
|
Ref();
|
||||||
return this;
|
return this;
|
||||||
|
@ -131,6 +131,8 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
|
|||||||
|
|
||||||
const char* DeviceName(Status* status) const override;
|
const char* DeviceName(Status* status) const override;
|
||||||
const char* BackingDeviceName(Status* status) const override;
|
const char* BackingDeviceName(Status* status) const override;
|
||||||
|
const char* DeviceType(Status* status) const override;
|
||||||
|
int DeviceId(Status* status) const override;
|
||||||
AbstractTensorInterface* Resolve(Status* status) override;
|
AbstractTensorInterface* Resolve(Status* status) override;
|
||||||
|
|
||||||
ImmediateExecutionTensorHandle* Copy() override;
|
ImmediateExecutionTensorHandle* Copy() override;
|
||||||
|
@ -408,4 +408,63 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
|
|||||||
context->Unref();
|
context->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies DeviceName/BackingDeviceName/DeviceType/DeviceId on handles
// created directly on local CPU and GPU devices via a StaticDeviceMgr.
// Fix: removed a leftover debug std::cout that polluted test output.
TEST(TensorHandle_DeviceNameTest, OnLocalDevice) {
  std::vector<std::unique_ptr<Device>> devices;
  devices.emplace_back(
      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
  devices.emplace_back(
      CreateDevice("GPU", "/job:localhost/replica:0/task:0/device:GPU:0"));
  StaticDeviceMgr local_device_mgr(std::move(devices));
  auto ctx = new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      false, &local_device_mgr, false, nullptr, nullptr);

  Device* dcpu = local_device_mgr.ListDevices()[0];
  Device* dgpu = local_device_mgr.ListDevices()[1];
  tensorflow::DataType dtype = DT_RESOURCE;
  TensorShape shape = {2};
  Tensor tcpu(dtype, shape);
  Tensor tgpu(dtype, shape);
  Status s;

  // CPU-placed handle: every device query should mention CPU, ID 0.
  TensorHandle* th_cpu =
      TensorHandle::CreateLocalHandle(std::move(tcpu), dcpu, dcpu, dcpu, ctx);
  const char* device_name = th_cpu->DeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_name, "CPU")) << device_name;
  const char* backing_device_name = th_cpu->BackingDeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU"))
      << backing_device_name;
  const char* device_type = th_cpu->DeviceType(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = th_cpu->DeviceId(&s);
  TF_EXPECT_OK(s);
  ASSERT_EQ(0, device_id) << device_id;

  // GPU-placed handle: every device query should mention GPU, ID 0.
  TensorHandle* th_gpu =
      TensorHandle::CreateLocalHandle(std::move(tgpu), dgpu, dgpu, dgpu, ctx);
  device_name = th_gpu->DeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_name, "GPU")) << device_name;
  backing_device_name = th_gpu->BackingDeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(backing_device_name, "GPU"))
      << backing_device_name;
  device_type = th_gpu->DeviceType(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
  device_id = th_gpu->DeviceId(&s);
  TF_EXPECT_OK(s);
  ASSERT_EQ(0, device_id) << device_id;

  // Handles hold references on the context; release everything.
  th_cpu->Unref();
  th_gpu->Unref();
  ctx->Unref();
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user