Support copying remote output tensors back from a remote device.
PiperOrigin-RevId: 303771801 Change-Id: I9cbbb1e156ad0ee580b4b56b660756aa911d385a
This commit is contained in:
parent
a60b5ac5fb
commit
238d1a70a7
@ -31,12 +31,6 @@ void EagerProcessFunctionLibraryRuntime::RunRemoteDevice(
|
||||
FunctionLibraryRuntime::Handle local_handle,
|
||||
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done) const {
|
||||
if (!rets->empty()) {
|
||||
done(
|
||||
errors::Unimplemented("Remote outputs are not supported by "
|
||||
"EagerClusterFunctionLibraryRuntime yet."));
|
||||
return;
|
||||
}
|
||||
parent_->Run(opts, local_handle, args, rets, std::move(done));
|
||||
}
|
||||
|
||||
|
@ -120,15 +120,11 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
||||
FunctionLibraryRuntime::Options opts_copy = opts;
|
||||
if (!opts_copy.op_id.has_value()) {
|
||||
opts_copy.op_id = ctx_->RemoteMgr()->NextOpId();
|
||||
}
|
||||
std::vector<FunctionArg> function_args;
|
||||
for (const auto& tensor : args) {
|
||||
function_args.push_back(tensor);
|
||||
}
|
||||
Run(opts_copy, handle, function_args, rets, std::move(done));
|
||||
std::vector<FunctionArg> function_args;
|
||||
for (const auto& tensor : args) {
|
||||
function_args.push_back(tensor);
|
||||
}
|
||||
Run(opts, handle, function_args, rets, std::move(done));
|
||||
}
|
||||
|
||||
void EagerClusterFunctionLibraryRuntime::Run(
|
||||
@ -165,11 +161,6 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
||||
|
||||
EagerOperation* op = function_data->op.get();
|
||||
|
||||
if (!opts.op_id.has_value()) {
|
||||
done(
|
||||
errors::Internal("op_id is not set for remote function: ", op->Name()));
|
||||
}
|
||||
|
||||
eager::EnqueueRequest* request = new eager::EnqueueRequest;
|
||||
request->set_context_id(context_id_);
|
||||
eager::Operation* remote_op = request->add_queue()->mutable_operation();
|
||||
@ -187,7 +178,11 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
||||
// The remote component function should use the same op_id as its parent
|
||||
// multi-device function's in order to get the global unique op_id generated
|
||||
// by the master context.
|
||||
remote_op->set_id(opts.op_id.value());
|
||||
if (opts.op_id.has_value()) {
|
||||
remote_op->set_id(opts.op_id.value());
|
||||
} else {
|
||||
remote_op->set_id(kInvalidRemoteOpId);
|
||||
}
|
||||
remote_op->set_is_function(true);
|
||||
remote_op->set_is_component_function(true);
|
||||
remote_op->set_func_step_id(opts.step_id);
|
||||
@ -203,15 +198,39 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
||||
// disabled, Run() returns when the remote function execution completes, which
|
||||
// might be blocked by a non-enqueued function execution.
|
||||
EnqueueResponse* response = new EnqueueResponse;
|
||||
eager_client->EnqueueAsync(request, response,
|
||||
[op, request, response, done](const Status& s) {
|
||||
for (auto handle : op->Inputs()) {
|
||||
handle->Unref();
|
||||
}
|
||||
done(s);
|
||||
delete request;
|
||||
delete response;
|
||||
});
|
||||
eager_client->EnqueueAsync(
|
||||
request, response,
|
||||
[op, request, response, rets, done = std::move(done)](const Status& s) {
|
||||
Status status = s;
|
||||
auto cleanup = gtl::MakeCleanup([request, response, &status, &done] {
|
||||
done(status);
|
||||
delete request;
|
||||
delete response;
|
||||
});
|
||||
|
||||
for (auto handle : op->Inputs()) {
|
||||
handle->Unref();
|
||||
}
|
||||
if (!status.ok()) {
|
||||
return;
|
||||
}
|
||||
if (response->queue_response_size() != 1) {
|
||||
status.Update(errors::Internal(
|
||||
"Expect that the size of response queue equals 1, but got: ",
|
||||
response->queue_response_size()));
|
||||
return;
|
||||
}
|
||||
for (const auto& tensor_proto : response->queue_response(0).tensor()) {
|
||||
Tensor t;
|
||||
if (t.FromProto(tensor_proto)) {
|
||||
rets->push_back(std::move(t));
|
||||
} else {
|
||||
status.Update(errors::Internal("Could not convert tensor proto: ",
|
||||
tensor_proto.DebugString()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void EagerClusterFunctionLibraryRuntime::CleanUp(
|
||||
|
@ -327,6 +327,13 @@ Status EagerServiceImpl::CreateMasterContext(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
TF_RETURN_IF_ERROR(handle->Tensor(&t));
|
||||
t->AsProtoTensorContent(proto);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
|
||||
@ -412,12 +419,21 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
|
||||
TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals));
|
||||
|
||||
eager_context->RemoteMgr()->AddOperationOutputs(
|
||||
absl::MakeSpan(retvals.data(), num_retvals), operation.id());
|
||||
|
||||
for (int i = 0; i < num_retvals; i++) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorHandleShape(retvals[i], queue_response->add_shape()));
|
||||
if (operation.id() == kInvalidRemoteOpId) {
|
||||
// Copy the output tensors back along with the response, since the op id
|
||||
// is invalid which cannot be added to RemoteMgr.
|
||||
for (int i = 0; i < num_retvals; i++) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorHandleProto(retvals[i], queue_response->add_tensor()));
|
||||
retvals[i]->Unref();
|
||||
}
|
||||
} else {
|
||||
eager_context->RemoteMgr()->AddOperationOutputs(
|
||||
absl::MakeSpan(retvals.data(), num_retvals), operation.id());
|
||||
for (int i = 0; i < num_retvals; i++) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorHandleShape(retvals[i], queue_response->add_shape()));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -560,13 +560,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
|
||||
}
|
||||
|
||||
void CheckOutputsAndClose(const int64 op_id) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* tensor_handle;
|
||||
TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
|
||||
context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
|
||||
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
||||
auto actual = t->flat<float>();
|
||||
void CheckOutputTensorAndClose(const Tensor& tensor) {
|
||||
auto actual = tensor.flat<float>();
|
||||
EXPECT_EQ(4, actual.size());
|
||||
EXPECT_EQ(7, actual(0));
|
||||
EXPECT_EQ(10, actual(1));
|
||||
@ -581,6 +576,15 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
&close_context_response));
|
||||
}
|
||||
|
||||
void CheckOutputsAndClose(const int64 op_id) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* tensor_handle;
|
||||
TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
|
||||
context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
|
||||
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
||||
CheckOutputTensorAndClose(*t);
|
||||
}
|
||||
|
||||
protected:
|
||||
const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
@ -649,8 +653,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
|
||||
CheckOutputsAndClose(op_id);
|
||||
}
|
||||
|
||||
// Test executes a remote function with a local tensor input.
|
||||
TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
|
||||
// Test executes a remote function with local input and output tensors.
|
||||
TEST_F(FunctionWithRemoteInputsTest,
|
||||
EagerClusterFLRTestWithLocalInputAndOutput) {
|
||||
Init();
|
||||
// Instantiate MatMulFunction on remote_device.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
@ -681,11 +686,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
|
||||
context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
|
||||
TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));
|
||||
|
||||
// Send input_tensor to the remote device and execute MatMulFunction on the
|
||||
// remote device.
|
||||
// Send input_tensor to the remote device, execute MatMulFunction on the
|
||||
// remote device, and send the output back.
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
const uint64 op_id = 2;
|
||||
opts.op_id = op_id;
|
||||
Notification execute_done;
|
||||
std::vector<Tensor> inputs = {*input_tensor};
|
||||
std::vector<Tensor> outputs;
|
||||
@ -696,7 +699,8 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) {
|
||||
});
|
||||
execute_done.WaitForNotification();
|
||||
TF_ASSERT_OK(status);
|
||||
CheckOutputsAndClose(op_id);
|
||||
EXPECT_EQ(outputs.size(), 1);
|
||||
CheckOutputTensorAndClose(outputs.at(0));
|
||||
}
|
||||
|
||||
// Test executes a remote function through KernelAndDeviceFunc.
|
||||
|
@ -26,6 +26,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
const int64 kInvalidRemoteOpId = -1;
|
||||
|
||||
// This class manages the states required to setup an eager cluster.
|
||||
// TODO(fishx): Move remote state from context to this class.
|
||||
class RemoteMgr {
|
||||
|
@ -73,7 +73,12 @@ message QueueItem {
|
||||
}
|
||||
|
||||
message QueueResponse {
|
||||
// `shape` and `tensor` cannot be set in the same response.
|
||||
// Shapes of output tensors for creating remote TensorHandles.
|
||||
repeated TensorShapeProto shape = 1;
|
||||
|
||||
// Output tensors of a remote function. Set when Operation.id is invalid.
|
||||
repeated TensorProto tensor = 2;
|
||||
}
|
||||
|
||||
message CreateContextRequest {
|
||||
|
@ -862,7 +862,10 @@ cuda_py_test(
|
||||
":def_function",
|
||||
":remote",
|
||||
":test",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -31,11 +31,14 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import server_lib
|
||||
@ -167,6 +170,26 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device('/job:worker/task:0'):
|
||||
self.assertAllEqual(func(), 1)
|
||||
|
||||
@test_util.eager_lazy_remote_copy_on_and_off
|
||||
def testRemoteCall(self):
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
|
||||
def _remote_fn(x):
|
||||
return constant_op.constant(1) + x
|
||||
|
||||
remote_fn = _remote_fn.get_concrete_function()
|
||||
|
||||
@def_function.function
|
||||
def func(x):
|
||||
return functional_ops.remote_call(
|
||||
args=[x],
|
||||
Tout=[dtypes.int32],
|
||||
f=remote_fn,
|
||||
target='/job:worker/task:0')
|
||||
|
||||
self.assertAllEqual(func(constant_op.constant(1)), [2])
|
||||
|
||||
|
||||
class RemoteAsyncTest(test.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user