Do not call Unprotect on remote inputs
Unprotect should only be called on local handles. To test that forwarding is triggered for remote inputs to a function, we add an optimization whereby EagerExecute releases the inputs of the eager operation. This enforces that a TFE_Op cannot be reused, since its inputs will have been removed. This was effectively already true: if the inputs were ever forwarded, the TFE_Op should not have been re-used. PiperOrigin-RevId: 306564949 Change-Id: I94bd3a243658277891867802b792a4492ec0a039
This commit is contained in:
parent
4d447f8486
commit
a9f8a9b1c1
@ -129,7 +129,45 @@ void TestRemoteExecute(bool async) {
|
||||
// Remote execution with a synchronous (blocking) eager executor.
TEST(CAPI, RemoteExecute) {
  TestRemoteExecute(/*async=*/false);
}
|
||||
// Remote execution with an asynchronous eager executor.
TEST(CAPI, RemoteExecuteAsync) {
  TestRemoteExecute(/*async=*/true);
}
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
// Builds a FunctionDef named 'MatMulFunction' with two DT_FLOAT inputs
// 'a' and 'b' and one output 'm', produced by a single MatMul node.
// Returns the proto serialized to a binary string, ready to be passed to
// TFE_ContextAddFunctionDef.
string MatMulFunction() {
  // Text-format proto for the function definition; adjacent literals are
  // concatenated by the compiler into one string.
  const char* const kFunctionTextProto =
      " signature {"
      " name: 'MatMulFunction'"
      " input_arg {"
      " name: 'a'"
      " type: DT_FLOAT"
      " }"
      " input_arg {"
      " name: 'b'"
      " type: DT_FLOAT"
      " }"
      " output_arg {"
      " name: 'm'"
      " type: DT_FLOAT"
      " }"
      " }"
      " node_def {"
      " name: 'matmul'"
      " op: 'MatMul'"
      " input: 'a'"
      " input: 'b'"
      " attr {"
      " key: 'T'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }"
      " }"
      " ret {"
      " key: 'm'"
      " value: 'matmul:product'"
      " }";
  tensorflow::FunctionDef function_def;
  // Parsing a compile-time-constant proto must never fail; CHECK enforces it.
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(kFunctionTextProto,
                                                          &function_def));
  return function_def.SerializeAsString();
}
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -169,10 +207,29 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
string function_def = MatMulFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
} else if (!async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -182,12 +239,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
if (!remote && !async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(matmul->operation);
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->Inputs()[1], remote_arg);
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
@ -217,6 +272,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
@ -227,16 +285,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
}
|
||||
|
||||
// Silent input copies, sync executor, op placed on a remote task, raw op
// (non-function) path.
TEST(CAPI, RemoteExecuteSilentCopies) {
  TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
                                /*func=*/false);
}
|
||||
// Silent input copies, async executor, op placed on a remote task, raw op
// (non-function) path.
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true,
                                /*func=*/false);
}
|
||||
// Silent input copies, async executor, op placed on a remote task,
// executing through the 'MatMulFunction' function path.
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true,
                                /*func=*/true);
}
|
||||
// Silent input copies, sync executor, op stays on the local task, raw op
// (non-function) path.
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
  TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
                                /*func=*/false);
}
|
||||
// Silent input copies, async executor, op stays on the local task, raw op
// (non-function) path.
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
                                /*func=*/false);
}
|
||||
// Silent input copies, async executor, op stays on the local task,
// executing through the 'MatMulFunction' function path.
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
                                /*func=*/true);
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
|
@ -78,11 +78,18 @@ void BM_Execute(int iters, int async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(matmul, "MatMul", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -113,11 +120,15 @@ void BM_Execute_Identity(int iters, int async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identity = IdentityOp(ctx, m);
|
||||
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(identity, "Identity", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(identity, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(identity, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -405,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
|
||||
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
string cpu_device_name;
|
||||
@ -420,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Validate if the input was replaced with a different TensorHandle
|
||||
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(matmul->operation);
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
EXPECT_EQ(op->Inputs()[0], arg0);
|
||||
EXPECT_EQ(op->Inputs()[1], arg1);
|
||||
// The CPU handle should have been copied and have a mirror on the GPU
|
||||
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
@ -626,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
|
@ -596,6 +596,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
|
||||
graph_collector, op->GetCancellationManager(),
|
||||
absl::Span<TensorHandle*>(retvals, num_outputs));
|
||||
// Release the inputs from the eager operation since the AsyncExecuteNode
|
||||
// would have taken ownership. This allows the inputs to be forwarded if
|
||||
// possible.
|
||||
op->Clear();
|
||||
// For async mode, execution order will make sure that all
|
||||
// input handles are ready before executing them.
|
||||
// TODO(b/137118203): Consider executing "cheap" kernels inline for
|
||||
@ -609,6 +613,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
graph_collector, op->GetCancellationManager(),
|
||||
{retvals, static_cast<size_t>(num_outputs)});
|
||||
s = executor.SyncExecute(&node);
|
||||
// We release the inputs AFTER executing the operation in sync mode since
|
||||
// ExecuteNode does not increment the reference count and thus does not have
|
||||
// ownership of the inputs while executing.
|
||||
op->Clear();
|
||||
}
|
||||
// Since the operation failed, we need to Unref any outputs if they were
|
||||
// allocated.
|
||||
|
@ -449,7 +449,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
|
||||
Status TensorHandle::Unprotect(const Device* d) {
|
||||
DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (d == absl::get<Device*>(device_)) {
|
||||
if (!IsRemote() && (d == absl::get<Device*>(device_))) {
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
return data.Unprotect();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user