Support copying remote output tensors back from a remote device.

PiperOrigin-RevId: 303771801
Change-Id: I9cbbb1e156ad0ee580b4b56b660756aa911d385a
This commit is contained in:
Yujing Zhang 2020-03-30 10:35:56 -07:00 committed by TensorFlower Gardener
parent a60b5ac5fb
commit 238d1a70a7
8 changed files with 116 additions and 50 deletions

View File

@ -31,12 +31,6 @@ void EagerProcessFunctionLibraryRuntime::RunRemoteDevice(
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
if (!rets->empty()) {
done(
errors::Unimplemented("Remote outputs are not supported by "
"EagerClusterFunctionLibraryRuntime yet."));
return;
}
parent_->Run(opts, local_handle, args, rets, std::move(done));
}

View File

@ -120,15 +120,11 @@ void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
FunctionLibraryRuntime::Options opts_copy = opts;
if (!opts_copy.op_id.has_value()) {
opts_copy.op_id = ctx_->RemoteMgr()->NextOpId();
}
std::vector<FunctionArg> function_args;
for (const auto& tensor : args) {
function_args.push_back(tensor);
}
Run(opts_copy, handle, function_args, rets, std::move(done));
std::vector<FunctionArg> function_args;
for (const auto& tensor : args) {
function_args.push_back(tensor);
}
Run(opts, handle, function_args, rets, std::move(done));
}
void EagerClusterFunctionLibraryRuntime::Run(
@ -165,11 +161,6 @@ void EagerClusterFunctionLibraryRuntime::Run(
EagerOperation* op = function_data->op.get();
if (!opts.op_id.has_value()) {
done(
errors::Internal("op_id is not set for remote function: ", op->Name()));
}
eager::EnqueueRequest* request = new eager::EnqueueRequest;
request->set_context_id(context_id_);
eager::Operation* remote_op = request->add_queue()->mutable_operation();
@ -187,7 +178,11 @@ void EagerClusterFunctionLibraryRuntime::Run(
// The remote component function should use the same op_id as its parent
// multi-device function, so that it shares the globally unique op_id
// generated by the master context.
remote_op->set_id(opts.op_id.value());
if (opts.op_id.has_value()) {
remote_op->set_id(opts.op_id.value());
} else {
remote_op->set_id(kInvalidRemoteOpId);
}
remote_op->set_is_function(true);
remote_op->set_is_component_function(true);
remote_op->set_func_step_id(opts.step_id);
@ -203,15 +198,39 @@ void EagerClusterFunctionLibraryRuntime::Run(
// disabled, Run() returns when the remote function execution completes, which
// might be blocked by a non-enqueued function execution.
EnqueueResponse* response = new EnqueueResponse;
eager_client->EnqueueAsync(request, response,
[op, request, response, done](const Status& s) {
for (auto handle : op->Inputs()) {
handle->Unref();
}
done(s);
delete request;
delete response;
});
eager_client->EnqueueAsync(
request, response,
[op, request, response, rets, done = std::move(done)](const Status& s) {
Status status = s;
auto cleanup = gtl::MakeCleanup([request, response, &status, &done] {
done(status);
delete request;
delete response;
});
for (auto handle : op->Inputs()) {
handle->Unref();
}
if (!status.ok()) {
return;
}
if (response->queue_response_size() != 1) {
status.Update(errors::Internal(
"Expect that the size of response queue equals 1, but got: ",
response->queue_response_size()));
return;
}
for (const auto& tensor_proto : response->queue_response(0).tensor()) {
Tensor t;
if (t.FromProto(tensor_proto)) {
rets->push_back(std::move(t));
} else {
status.Update(errors::Internal("Could not convert tensor proto: ",
tensor_proto.DebugString()));
return;
}
}
});
}
void EagerClusterFunctionLibraryRuntime::CleanUp(

View File

@ -327,6 +327,13 @@ Status EagerServiceImpl::CreateMasterContext(
return Status::OK();
}
// Serializes the tensor behind `handle` into `proto`.
//
// Resolves `handle` to its underlying tensor via TensorHandle::Tensor and, on
// success, writes it into `proto` with Tensor::AsProtoTensorContent.
// Returns the resolution error unchanged if the handle cannot be resolved;
// otherwise returns OK.
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
  const tensorflow::Tensor* resolved = nullptr;
  Status resolve_status = handle->Tensor(&resolved);
  if (!resolve_status.ok()) {
    return resolve_status;
  }
  resolved->AsProtoTensorContent(proto);
  return Status::OK();
}
Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
const tensorflow::Tensor* t = nullptr;
@ -412,12 +419,21 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals));
eager_context->RemoteMgr()->AddOperationOutputs(
absl::MakeSpan(retvals.data(), num_retvals), operation.id());
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(
TensorHandleShape(retvals[i], queue_response->add_shape()));
if (operation.id() == kInvalidRemoteOpId) {
// Copy the output tensors back along with the response, since an invalid
// op id cannot be added to RemoteMgr.
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(
TensorHandleProto(retvals[i], queue_response->add_tensor()));
retvals[i]->Unref();
}
} else {
eager_context->RemoteMgr()->AddOperationOutputs(
absl::MakeSpan(retvals.data(), num_retvals), operation.id());
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(
TensorHandleShape(retvals[i], queue_response->add_shape()));
}
}
return Status::OK();

View File

@ -560,13 +560,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
}
void CheckOutputsAndClose(const int64 op_id) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
auto actual = t->flat<float>();
void CheckOutputTensorAndClose(const Tensor& tensor) {
auto actual = tensor.flat<float>();
EXPECT_EQ(4, actual.size());
EXPECT_EQ(7, actual(0));
EXPECT_EQ(10, actual(1));
@ -581,6 +576,15 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
&close_context_response));
}
void CheckOutputsAndClose(const int64 op_id) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
CheckOutputTensorAndClose(*t);
}
protected:
const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
@ -649,8 +653,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
CheckOutputsAndClose(op_id);
}
// Test executes a remote function with a local tensor input.
TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
// Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,
EagerClusterFLRTestWithLocalInputAndOutput) {
Init();
// Instantiate MatMulFunction on remote_device.
FunctionLibraryRuntime::Handle handle;
@ -681,11 +686,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));
// Send input_tensor to the remote device and execute MatMulFunction on the
// remote device.
// Send input_tensor to the remote device, execute MatMulFunction on the
// remote device, and send the output back.
FunctionLibraryRuntime::Options opts;
const uint64 op_id = 2;
opts.op_id = op_id;
Notification execute_done;
std::vector<Tensor> inputs = {*input_tensor};
std::vector<Tensor> outputs;
@ -696,7 +699,8 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
});
execute_done.WaitForNotification();
TF_ASSERT_OK(status);
CheckOutputsAndClose(op_id);
EXPECT_EQ(outputs.size(), 1);
CheckOutputTensorAndClose(outputs.at(0));
}
// Test executes a remote function through KernelAndDeviceFunc.

View File

@ -26,6 +26,8 @@ limitations under the License.
namespace tensorflow {
namespace eager {
// Sentinel op id for a remote function call that carries no valid op id.
// The cluster FLR sets it on the remote operation when `opts.op_id` is unset;
// the eager service then copies the output tensors back in the response
// instead of registering them with RemoteMgr.
const int64 kInvalidRemoteOpId = -1;
// This class manages the states required to setup an eager cluster.
// TODO(fishx): Move remote state from context to this class.
class RemoteMgr {

View File

@ -73,7 +73,12 @@ message QueueItem {
}
message QueueResponse {
// `shape` and `tensor` cannot be set in the same response.
// Shapes of output tensors for creating remote TensorHandles. Set when the
// outputs stay on the remote side (the operation has a valid id).
repeated TensorShapeProto shape = 1;
// Output tensors of a remote function. Set when Operation.id is invalid
// (kInvalidRemoteOpId), in which case the tensors themselves are copied back
// with the response.
repeated TensorProto tensor = 2;
}
message CreateContextRequest {

View File

@ -862,7 +862,10 @@ cuda_py_test(
":def_function",
":remote",
":test",
"//tensorflow/python:dtypes",
"//tensorflow/python:functional_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],

View File

@ -31,11 +31,14 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
@ -167,6 +170,26 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
with ops.device('/job:worker/task:0'):
self.assertAllEqual(func(), 1)
# Verifies that a function invoked through `functional_ops.remote_call` on
# '/job:worker/task:0' returns its output tensor to the caller.
@test_util.eager_lazy_remote_copy_on_and_off
def testRemoteCall(self):
# Function to be called remotely: adds 1 to a scalar int32 input.
@def_function.function(
input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
def _remote_fn(x):
return constant_op.constant(1) + x
remote_fn = _remote_fn.get_concrete_function()
# Wrapper that dispatches `remote_fn` to the worker task via remote_call.
@def_function.function
def func(x):
return functional_ops.remote_call(
args=[x],
Tout=[dtypes.int32],
f=remote_fn,
target='/job:worker/task:0')
# 1 + 1 computed remotely; the result must be available to the caller.
self.assertAllEqual(func(constant_op.constant(1)), [2])
class RemoteAsyncTest(test.TestCase):