Enable eager tensor mirroring all the time

This greatly improve performance by reducing any repeated host to device
copies needed for op execution.

PiperOrigin-RevId: 297039608
Change-Id: Ic5f0d5c4ff15f1af096ba0ea20e96af76e53a6d8
This commit is contained in:
Gaurav Jain 2020-02-24 21:49:08 -08:00 committed by TensorFlower Gardener
parent e10fba7de3
commit 7fbfccbd30
2 changed files with 12 additions and 26 deletions

View File

@ -369,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool mirror, bool cpu_op) {
bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -392,12 +392,6 @@ void TensorHandleSilentCopy(bool async,
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (mirror) {
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
}
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
@ -424,7 +418,7 @@ void TensorHandleSilentCopy(bool async,
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
if (mirror) {
if (!async) {
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
@ -454,27 +448,19 @@ void TensorHandleSilentCopy(bool async,
}
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleMirrorCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, false);
}
TEST(CAPI, TensorHandleMirrorCopyCpu) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, true);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
void SetAndGetOpDevices(bool async) {

View File

@ -143,7 +143,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
@ -164,7 +164,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(true),
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
tensor_handle_data_(std::move(t)) {
@ -185,7 +185,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
// TODO(allenl): Figure out a better op_device story for custom devices,
@ -220,7 +220,7 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
ctx_(ctx),
is_remote_(false),
is_async_(async),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(!async),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating empty Local TensorHandle: " << this
@ -260,7 +260,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
ctx_(ctx),
is_remote_(true),
is_async_(false),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Remote TensorHandle: " << this
@ -297,7 +297,7 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
ctx_(ctx),
is_remote_(true),
is_async_(true),
implicit_mirroring_(false),
implicit_mirroring_(true),
is_ready_(false),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this