Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.

PiperOrigin-RevId: 168038285
This commit is contained in:
Rohan Jain 2017-09-08 13:30:17 -07:00 committed by TensorFlower Gardener
parent 82cc6529f4
commit 450c3b5626
16 changed files with 747 additions and 174 deletions

View File

@ -24,6 +24,8 @@ py_test(
"//tensorflow/python:functional_ops", "//tensorflow/python:functional_ops",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training", "//tensorflow/python:training",
"//third_party/py/numpy", "//third_party/py/numpy",
], ],

View File

@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
def testRemoteIteratorUsingRemoteCallOpDirectSession(self): def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto() worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2 worker_config.device_count["CPU"] = 3
with ops.device("/job:localhost/replica:0/task:0/cpu:1"): with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
}) })
self.assertEqual(elem, [1]) self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located. # Fails when target is cpu:2 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
sess.run( sess.run(
remote_op, remote_op,
feed_dict={ feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
}) })
elem = sess.run( elem = sess.run(
remote_op, remote_op,
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
}) })
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
def _encode_raw(byte_array):
return "".join([chr(item) for item in byte_array])
@function.Defun(dtypes.uint8)
def _remote_fn(h):
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
remote_iterator = dataset_ops.Iterator.from_string_handle(
handle, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
iterator_3_handle_uint8 = parsing_ops.decode_raw(
bytes=iterator_3_handle, out_type=dtypes.uint8)
remote_op = functional_ops.remote_call(
args=[iterator_3_handle_uint8],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)
with self.test_session() as sess:
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [1])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -1785,6 +1785,7 @@ tf_cuda_library(
"common_runtime/process_util.cc", "common_runtime/process_util.cc",
"common_runtime/renamed_device.cc", "common_runtime/renamed_device.cc",
"common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_mgr.cc",
"common_runtime/rendezvous_util.cc",
"common_runtime/resource_variable_read_optimizer.cc", "common_runtime/resource_variable_read_optimizer.cc",
"common_runtime/session.cc", "common_runtime/session.cc",
"common_runtime/session_factory.cc", "common_runtime/session_factory.cc",
@ -1828,6 +1829,7 @@ tf_cuda_library(
"common_runtime/profile_handler.h", "common_runtime/profile_handler.h",
"common_runtime/renamed_device.h", "common_runtime/renamed_device.h",
"common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_mgr.h",
"common_runtime/rendezvous_util.h",
"common_runtime/session_factory.h", "common_runtime/session_factory.h",
"common_runtime/graph_execution_state.h", "common_runtime/graph_execution_state.h",
"common_runtime/placer.h", "common_runtime/placer.h",
@ -2669,29 +2671,29 @@ tf_cc_test(
srcs = ["common_runtime/process_function_library_runtime_test.cc"], srcs = ["common_runtime/process_function_library_runtime_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(), linkstatic = tf_kernel_tests_linkstatic(),
deps = [ deps = [
":core",
":core_cpu", ":core_cpu",
":core_cpu_internal", ":core_cpu_internal",
":direct_session_internal",
":framework", ":framework",
":framework_internal",
":lib",
":lib_internal",
":ops",
":protos_all_cc",
":test", ":test",
":test_main", ":test_main",
":testlib", ":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops", "//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:matmul_op", ],
"//tensorflow/core/kernels:shape_ops", )
"//third_party/eigen3",
tf_cc_test(
name = "common_runtime_rendezvous_util_test",
size = "small",
srcs = ["common_runtime/rendezvous_util_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core_cpu_internal",
":lib",
":test",
":test_main",
], ],
) )

View File

@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionBody** g_body); FunctionBody** g_body);
bool IsLocalTarget(const AttrSlice& attrs); bool IsLocalTarget(const AttrSlice& attrs);
AttrValueMap FixAttrs(const AttrSlice& attrs); AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
Executor::Args* exec_args, Item* item, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
}; };
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return Status::OK(); return Status::OK();
} }
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
Executor::Args* exec_args,
Item* item, DoneCallback done) {
FunctionCallFrame* frame = exec_args->call_frame;
string target_device = parent_->GetDeviceName(handle);
string source_device = opts.source_device;
Rendezvous* rendezvous = opts.rendezvous;
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
Rendezvous::Args rendez_args;
Status s =
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", args.size(), rendez_args,
rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device, rendezvous,
rendez_args, rets, done, exec_args](const Status& status) {
Status s = status;
s = frame->SetArgs(*remote_args);
if (!s.ok()) {
delete frame;
delete remote_args;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync(
*exec_args,
[item, frame, rets, done, source_device, target_device, rendezvous,
rendez_args, remote_args, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
delete frame;
if (!s.ok()) {
delete remote_args;
delete exec_args;
done(s);
return;
}
s = ProcessFunctionLibraryRuntime::SendTensors(
target_device, source_device, "ret_", *rets, rendez_args,
rendezvous);
delete remote_args;
delete exec_args;
done(s);
});
});
}
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, std::vector<Tensor>* rets,
DoneCallback done) { DoneCallback done) {
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
return done(errors::Cancelled("")); done(errors::Cancelled(""));
return;
} }
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->Run(opts, handle, args, rets, done); parent_->Run(opts, handle, args, rets, done);
return;
} }
const FunctionBody* fbody = GetFunctionBody(handle); const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame = FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types); new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
Status s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
return done(s);
}
Item* item = nullptr; Item* item = nullptr;
s = GetOrCreateItem(handle, &item); Status s = GetOrCreateItem(handle, &item);
if (!s.ok()) { if (!s.ok()) {
delete frame; delete frame;
return done(s); done(s);
return;
} }
DCHECK(opts.runner != nullptr); DCHECK(opts.runner != nullptr);
Executor::Args exec_args; Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller. // Inherit the step_id from the caller.
exec_args.step_id = opts.step_id; exec_args->step_id = opts.step_id;
exec_args.rendezvous = opts.rendezvous; exec_args->rendezvous = opts.rendezvous;
exec_args.stats_collector = opts.stats_collector; exec_args->stats_collector = opts.stats_collector;
exec_args.call_frame = frame; exec_args->call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager; exec_args->cancellation_manager = opts.cancellation_manager;
exec_args.step_container = opts.step_container; exec_args->step_container = opts.step_container;
exec_args.runner = *opts.runner; exec_args->runner = *opts.runner;
if (opts.remote_execution) {
RunRemote(opts, handle, args, rets, exec_args, item, done);
return;
}
s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync( item->exec->RunAsync(
// Executor args // Executor args
exec_args, *exec_args,
// Done callback. // Done callback.
[item, frame, rets, done](const Status& status) { [item, frame, rets, done, exec_args](const Status& status) {
item->Unref(); item->Unref();
Status s = status; Status s = status;
if (s.ok()) { if (s.ok()) {
s = frame->GetRetvals(rets); s = frame->ConsumeRetvals(rets);
} }
delete frame; delete frame;
delete exec_args;
done(s); done(s);
}); });
} }

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
} }
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) { const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
std::atomic<int32> call_count(0); std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner = std::function<void(std::function<void()>)> runner =
@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}; };
Notification done; Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner; opts.runner = &runner;
std::vector<Tensor> out; std::vector<Tensor> out;
Status status; Status status;
@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
if (!status.ok()) { if (!status.ok()) {
return status; return status;
} }
return Run(flr, handle, args, std::move(rets)); FunctionLibraryRuntime::Options opts;
return Run(flr, handle, opts, args, std::move(rets));
} }
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
Tensor y; Tensor y;
FunctionLibraryRuntime::Options opts;
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({}))); TensorShape({})));
TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({}))); TensorShape({})));
opts.rendezvous->Unref();
} }
namespace { namespace {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <utility> #include <utility>
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow { namespace tensorflow {
@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
} }
} }
/* static */
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
const AttrSlice& attrs) { const AttrSlice& attrs) {
const AttrValue* value; const AttrValue* value;
@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
return value->s(); return value->s();
} }
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args, Rendezvous* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
TF_RETURN_IF_ERROR(
SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send));
return Status::OK();
}
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
RecvOutputsFromRendezvousAsync(
rendezvous, args, keys, received_tensors,
[done](const Status& status) { done(status); });
}
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
const string& device_name, DeviceContext** device_context) {
*device_context = nullptr;
FunctionLibraryRuntime* flr = GetFLR(device_name);
if (flr == nullptr) {
return errors::InvalidArgument("Device name: ", device_name, " not found.");
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
if (device_type == "CPU") return Status::OK();
if (device_type == "GPU") {
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
if (dev_info) {
*device_context = dev_info->default_context;
return Status::OK();
}
}
return errors::Internal("Device type: ", device_type,
" is currently unsupported for remote ",
"function executions");
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) { const string& device_name) {
if (flr_map_.find(device_name) == flr_map_.end()) { if (flr_map_.find(device_name) == flr_map_.end()) {
@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice( ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) { const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p = std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle]; function_data_[handle];
if (p.first != device_name) { if (p.first != device_name) {
@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
return p.second; return p.second;
} }
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
return p.first;
}
Status ProcessFunctionLibraryRuntime::Instantiate( Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs, const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle) { FunctionLibraryRuntime::Handle* handle) {
@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts, const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) { std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
if (!opts.remote_execution) {
done(errors::InvalidArgument(
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
"is a remote execution."));
return;
}
FunctionLibraryRuntime* flr = nullptr; FunctionLibraryRuntime* flr = nullptr;
string target_device;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p = std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle]; function_data_[handle];
target_device = p.first;
flr = GetFLR(p.first); flr = GetFLR(p.first);
} }
if (flr != nullptr) { if (flr != nullptr) {
return flr->Run(opts, handle, args, rets, std::move(done)); auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
Rendezvous::Args rendez_args;
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
if (!s.ok()) {
done(s);
return;
}
// Send the args over to the target device.
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
rendezvous);
if (!s.ok()) {
done(s);
return;
}
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
[source_device, target_device, rendezvous, remote_rets, rets, done,
rendez_args](const Status& status) {
if (!status.ok()) {
delete remote_rets;
done(status);
return;
}
int64 num_returns = remote_rets->size();
delete remote_rets;
// Now receive the return values from the target.
ReceiveTensorsAsync(target_device, source_device, "ret_",
num_returns, rendez_args, rendezvous, rets,
done);
});
} else {
done(errors::Internal("Could not find device"));
return;
} }
} }

View File

@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
// attribute, returns "". Canonicalizes the device name. // attribute, returns "". Canonicalizes the device name.
static string ObtainFunctionTarget(const AttrSlice& attrs); static string ObtainFunctionTarget(const AttrSlice& attrs);
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. Method takes references on each of the `tensors_to_send`.
// Method doesn't block.
static Status SendTensors(const string& source_device,
const string& target_device,
const string& key_prefix,
gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args,
Rendezvous* rendezvous);
typedef std::function<void(const Status&)> StatusCallback;
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. Method doesn't block and calls `done` when
// `num_tensors` are fetched.
static void ReceiveTensorsAsync(const string& source_device,
const string& target_device,
const string& key_prefix, int64 num_tensors,
const Rendezvous::Args& args,
Rendezvous* rendezvous,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
static const char kDefaultFLRDevice[]; static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name. // Returns the FunctionLibraryRuntime for the corresponding device_name.
FunctionLibraryRuntime* GetFLR(const string& device_name); FunctionLibraryRuntime* GetFLR(const string& device_name);
@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback done); FunctionLibraryRuntime::DoneCallback done);
private: private:
// For a given device_name, returns a DeviceContext for copying
// tensors to/from the device.
Status GetDeviceContext(const string& device_name,
DeviceContext** device_context);
// Looks up the information for the given `handle` and returns the name
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
friend class FunctionLibraryRuntimeImpl;
mutable mutex mu_; mutable mutex mu_;
// Holds all the function invocations here. // Holds all the function invocations here.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
proc_flr_.reset(new ProcessFunctionLibraryRuntime( proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts)); opts));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
} }
Status Run(const string& name, test::function::Attrs attrs, Status Run(const string& name, FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) { test::function::Attrs attrs, const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle; FunctionLibraryRuntime::Handle handle;
Status status = proc_flr_->Instantiate(name, attrs, &handle); Status status = proc_flr_->Instantiate(name, attrs, &handle);
if (!status.ok()) { if (!status.ok()) {
@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}; };
Notification done; Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner; opts.runner = &runner;
std::vector<Tensor> out; std::vector<Tensor> out;
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_; std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
IntraProcessRendezvous* rendezvous_;
}; };
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
EXPECT_EQ(flr->device(), devices_[1]); EXPECT_EQ(flr->device(), devices_[1]);
flr = proc_flr_->GetFLR("abc"); flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr); EXPECT_EQ(flr, nullptr);
rendezvous_->Unref();
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
Init({test::function::XTimesTwo()}); Init({test::function::XTimesTwo()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4}); auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y; Tensor y;
TF_CHECK_OK( TF_CHECK_OK(
Run("XTimesTwo", Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y})); {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
Init({test::function::FindDevice()}); Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y; Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, TF_CHECK_OK(Run("FindDevice", opts,
{}, {&y})); {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"}, y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({}))); TensorShape({})));
rendezvous_->Unref();
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()}); Init({test::function::XTimesTwo(), test::function::XTimesFour()});
auto x = test::AsTensor<float>({1, 2, 3, 4}); auto x = test::AsTensor<float>({1, 2, 3, 4});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y; Tensor y;
TF_CHECK_OK( TF_CHECK_OK(
Run("XTimesTwo", Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y})); {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK( TF_CHECK_OK(
Run("XTimesFour", Run("XTimesFour", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y})); {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
Init({test::function::FindDevice()}); Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y; Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, TF_CHECK_OK(Run("FindDevice", opts,
{}, {&y})); {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({}))); TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, TF_CHECK_OK(Run("FindDevice", opts,
{}, {&y})); {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({}))); TensorShape({})));
rendezvous_->Unref();
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
Init({test::function::FindDevice()}); Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y; Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, TF_CHECK_OK(Run("FindDevice", opts,
{}, {&y})); {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"}, y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({}))); TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, TF_CHECK_OK(Run("FindDevice", opts,
{}, {&y})); {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>( test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({}))); TensorShape({})));
rendezvous_->Unref();
} }
} // anonymous namespace } // anonymous namespace

View File

@ -0,0 +1,119 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
namespace tensorflow {
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send) {
if (keys.size() != tensors_to_send.size()) {
return errors::InvalidArgument(
"keys and tensors_to_send are not the same size. keys.size() = ",
keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
}
Rendezvous::ParsedKey parsed;
for (int i = 0; i < keys.size(); ++i) {
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
TF_RETURN_IF_ERROR(
rendezvous->Send(parsed, args, tensors_to_send[i], false));
}
return Status::OK();
}
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
if (keys.empty()) {
done(Status::OK());
return;
}
received_tensors->reserve(keys.size());
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
for (int i = 0; i < keys.size(); ++i) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(keys[i], &parsed);
received_tensors->push_back(Tensor());
if (!s.ok()) {
done(s);
return;
}
arguments.push_back(
std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
}
typedef struct {
mutex mu;
int64 done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = keys.size();
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, args,
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
if (is_dead) {
return errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
namespace tensorflow {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
// Uses `rendezvous` to send tensors in `in`.
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send);
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_

View File

@ -0,0 +1,94 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class RendezvousUtilTest : public ::testing::Test {
public:
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
~RendezvousUtilTest() override { rendez_->Unref(); }
Rendezvous* rendez_;
};
// string -> Tensor<string>
Tensor V(const string& content) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<string>()() = content;
return tensor;
}
// Tensor<string> -> string
string V(const Tensor& tensor) {
CHECK_EQ(tensor.dtype(), DT_STRING);
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
return tensor.scalar<string>()();
}
string MakeStringKey(const string& name) {
return Rendezvous::CreateKey(
"/job:localhost/replica:0/task:0/device:CPU:0", 0,
"/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
}
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
// Fire off sends before receive the tensors.
Rendezvous::Args args;
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
// Fire off recvs, wait for a notification in the callback.
Rendezvous::Args args;
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/cancellation.h"
@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
return Status::OK(); return Status::OK();
} }
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
const NamedTensors& in) {
Rendezvous::ParsedKey parsed;
for (const auto& p : in) {
const string& key = p.first;
const Tensor& val = p.second;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
}
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
NamedTensors* out) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
}
if (is_dead) {
s = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
if (!s.ok()) return s;
}
return Status::OK();
}
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
NamedTensors* out,
const StatusCallback& done) {
if (out->empty()) {
done(Status::OK());
return;
}
// We compute the args before calling RecvAsync because we need to ensure that
// out isn't being iterated over after done is called, since done deletes out.
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
for (auto& p : *out) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(p.first, &parsed);
if (!s.ok()) {
done(s);
return;
}
args.push_back(std::make_tuple(p.first, &p.second, parsed));
}
typedef struct {
mutex mu;
int done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = out->size();
for (auto& p : args) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, Rendezvous::Args(),
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in); std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (const auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
} }
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out); Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
} }
@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) { StatusCallback done) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out, std::vector<string> keys;
[done, rendezvous](const Status s) { std::vector<Tensor>* received_keys = new std::vector<Tensor>;
rendezvous->Unref(); keys.reserve(out->size());
done(s); received_keys->reserve(out->size());
}); for (const auto& p : *out) {
keys.push_back(p.first);
received_keys->push_back(p.second);
}
RecvOutputsFromRendezvousAsync(
rendezvous, Rendezvous::Args(), keys, received_keys,
[done, rendezvous, received_keys, out, keys](const Status s) {
rendezvous->Unref();
for (int i = 0; i < keys.size(); ++i) {
(*out)[keys[i]] = (*received_keys)[i];
}
delete received_keys;
done(s);
});
} }
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
// Sends values specified by the caller. // Sends values specified by the caller.
if (s.ok()) { if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in); std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
} }
if (!s.ok()) { if (!s.ok()) {

View File

@ -169,11 +169,6 @@ class GraphMgr {
void BuildCostModel(Item* item, StepStatsCollector* collector, void BuildCostModel(Item* item, StepStatsCollector* collector,
CostGraphDef* cost_graph); CostGraphDef* cost_graph);
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
const StatusCallback& done);
Status InitItem(const string& session, const GraphDef& gdef, Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, const GraphOptions& graph_options,
const DebugOptions& debug_options, Item* item); const DebugOptions& debug_options, Item* item);

View File

@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
StepStatsCollector* stats_collector = nullptr; StepStatsCollector* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr; std::function<void(std::function<void()>)>* runner = nullptr;
// Parameters for remote function execution.
bool remote_execution = false;
string source_device = ""; // Fully specified device name.
}; };
typedef std::function<void(const Status&)> DoneCallback; typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle, virtual void Run(const Options& opts, Handle handle,

View File

@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
AttrValueMap attr_values = func_->attr(); AttrValueMap attr_values = func_->attr();
AttrValue v; AttrValue v;
v.set_s(target->scalar<string>()()); const string& target_device = target->scalar<string>()();
v.set_s(target_device);
AddAttr("_target", v, &attr_values); AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library(); FunctionLibraryRuntime* lib = ctx->function_library();
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts; FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id(); opts.step_id = ctx->step_id();
opts.runner = ctx->runner(); opts.runner = ctx->runner();
opts.source_device = lib->device()->name();
if (opts.source_device != target_device) {
opts.remote_execution = true;
}
opts.rendezvous = ctx->rendezvous();
std::vector<Tensor> args; std::vector<Tensor> args;
args.reserve(arguments.size()); args.reserve(arguments.size());
for (const Tensor& argument : arguments) { for (const Tensor& argument : arguments) {
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
}; };
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp); REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp); Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL #if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp); REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow } // namespace tensorflow

View File

@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op) mul = sess.run(remote_op)
self.assertEqual(mul, [6]) self.assertEqual(mul, [6])
def testRemoteFunctionCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
def testRemoteFunctionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()