Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.
PiperOrigin-RevId: 168038285
This commit is contained in:
parent
82cc6529f4
commit
450c3b5626
@ -24,6 +24,8 @@ py_test(
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
|
@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
|
||||
|
||||
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
||||
worker_config = config_pb2.ConfigProto()
|
||||
worker_config.device_count["CPU"] = 2
|
||||
worker_config.device_count["CPU"] = 3
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||
})
|
||||
self.assertEqual(elem, [1])
|
||||
# Fails when target is cpu:0 where the resource is not located.
|
||||
# Fails when target is cpu:2 where the resource is not located.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(
|
||||
remote_op,
|
||||
feed_dict={
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
|
||||
})
|
||||
elem = sess.run(
|
||||
remote_op,
|
||||
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||
})
|
||||
|
||||
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
iterator_3 = dataset_3.make_one_shot_iterator()
|
||||
iterator_3_handle = iterator_3.string_handle()
|
||||
|
||||
def _encode_raw(byte_array):
|
||||
return "".join([chr(item) for item in byte_array])
|
||||
|
||||
@function.Defun(dtypes.uint8)
|
||||
def _remote_fn(h):
|
||||
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
|
||||
remote_iterator = dataset_ops.Iterator.from_string_handle(
|
||||
handle, dataset_3.output_types, dataset_3.output_shapes)
|
||||
return remote_iterator.get_next()
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||
iterator_3_handle_uint8 = parsing_ops.decode_raw(
|
||||
bytes=iterator_3_handle, out_type=dtypes.uint8)
|
||||
remote_op = functional_ops.remote_call(
|
||||
args=[iterator_3_handle_uint8],
|
||||
Tout=[dtypes.int32],
|
||||
f=_remote_fn,
|
||||
target=target_placeholder)
|
||||
|
||||
with self.test_session() as sess:
|
||||
elem = sess.run(
|
||||
remote_op,
|
||||
feed_dict={
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||
})
|
||||
self.assertEqual(elem, [1])
|
||||
elem = sess.run(
|
||||
remote_op,
|
||||
feed_dict={
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||
})
|
||||
self.assertEqual(elem, [2])
|
||||
elem = sess.run(
|
||||
remote_op,
|
||||
feed_dict={
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||
})
|
||||
self.assertEqual(elem, [3])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(
|
||||
remote_op,
|
||||
feed_dict={
|
||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1785,6 +1785,7 @@ tf_cuda_library(
|
||||
"common_runtime/process_util.cc",
|
||||
"common_runtime/renamed_device.cc",
|
||||
"common_runtime/rendezvous_mgr.cc",
|
||||
"common_runtime/rendezvous_util.cc",
|
||||
"common_runtime/resource_variable_read_optimizer.cc",
|
||||
"common_runtime/session.cc",
|
||||
"common_runtime/session_factory.cc",
|
||||
@ -1828,6 +1829,7 @@ tf_cuda_library(
|
||||
"common_runtime/profile_handler.h",
|
||||
"common_runtime/renamed_device.h",
|
||||
"common_runtime/rendezvous_mgr.h",
|
||||
"common_runtime/rendezvous_util.h",
|
||||
"common_runtime/session_factory.h",
|
||||
"common_runtime/graph_execution_state.h",
|
||||
"common_runtime/placer.h",
|
||||
@ -2669,29 +2671,29 @@ tf_cc_test(
|
||||
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":core",
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":direct_session_internal",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":ops",
|
||||
":protos_all_cc",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/core/kernels:cast_op",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "common_runtime_rendezvous_util_test",
|
||||
size = "small",
|
||||
srcs = ["common_runtime/rendezvous_util_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":core_cpu_internal",
|
||||
":lib",
|
||||
":test",
|
||||
":test_main",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
FunctionBody** g_body);
|
||||
bool IsLocalTarget(const AttrSlice& attrs);
|
||||
AttrValueMap FixAttrs(const AttrSlice& attrs);
|
||||
void RunRemote(const Options& opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
|
||||
Executor::Args* exec_args, Item* item, DoneCallback done);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
|
||||
};
|
||||
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
Executor::Args* exec_args,
|
||||
Item* item, DoneCallback done) {
|
||||
FunctionCallFrame* frame = exec_args->call_frame;
|
||||
string target_device = parent_->GetDeviceName(handle);
|
||||
string source_device = opts.source_device;
|
||||
Rendezvous* rendezvous = opts.rendezvous;
|
||||
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
|
||||
Rendezvous::Args rendez_args;
|
||||
Status s =
|
||||
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
|
||||
if (!s.ok()) {
|
||||
delete frame;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
|
||||
// The ProcFLR sends the arguments to the function from the source_device to
|
||||
// the target_device. So here we receive those arguments. Similarly, when the
|
||||
// computation is done and stored in *rets, we send the return values back
|
||||
// to the source_device (caller) so that the ProcFLR can receive them later.
|
||||
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
|
||||
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||
source_device, target_device, "arg_", args.size(), rendez_args,
|
||||
rendezvous, remote_args,
|
||||
[frame, remote_args, item, source_device, target_device, rendezvous,
|
||||
rendez_args, rets, done, exec_args](const Status& status) {
|
||||
Status s = status;
|
||||
s = frame->SetArgs(*remote_args);
|
||||
if (!s.ok()) {
|
||||
delete frame;
|
||||
delete remote_args;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
item->exec->RunAsync(
|
||||
*exec_args,
|
||||
[item, frame, rets, done, source_device, target_device, rendezvous,
|
||||
rendez_args, remote_args, exec_args](const Status& status) {
|
||||
item->Unref();
|
||||
Status s = status;
|
||||
if (s.ok()) {
|
||||
s = frame->ConsumeRetvals(rets);
|
||||
}
|
||||
delete frame;
|
||||
if (!s.ok()) {
|
||||
delete remote_args;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
s = ProcessFunctionLibraryRuntime::SendTensors(
|
||||
target_device, source_device, "ret_", *rets, rendez_args,
|
||||
rendezvous);
|
||||
delete remote_args;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
DoneCallback done) {
|
||||
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
|
||||
return done(errors::Cancelled(""));
|
||||
done(errors::Cancelled(""));
|
||||
return;
|
||||
}
|
||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
||||
return parent_->Run(opts, handle, args, rets, done);
|
||||
parent_->Run(opts, handle, args, rets, done);
|
||||
return;
|
||||
}
|
||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||
FunctionCallFrame* frame =
|
||||
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
|
||||
Status s = frame->SetArgs(args);
|
||||
if (!s.ok()) {
|
||||
delete frame;
|
||||
return done(s);
|
||||
}
|
||||
|
||||
Item* item = nullptr;
|
||||
s = GetOrCreateItem(handle, &item);
|
||||
Status s = GetOrCreateItem(handle, &item);
|
||||
if (!s.ok()) {
|
||||
delete frame;
|
||||
return done(s);
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
DCHECK(opts.runner != nullptr);
|
||||
|
||||
Executor::Args exec_args;
|
||||
Executor::Args* exec_args = new Executor::Args;
|
||||
// Inherit the step_id from the caller.
|
||||
exec_args.step_id = opts.step_id;
|
||||
exec_args.rendezvous = opts.rendezvous;
|
||||
exec_args.stats_collector = opts.stats_collector;
|
||||
exec_args.call_frame = frame;
|
||||
exec_args.cancellation_manager = opts.cancellation_manager;
|
||||
exec_args.step_container = opts.step_container;
|
||||
exec_args.runner = *opts.runner;
|
||||
exec_args->step_id = opts.step_id;
|
||||
exec_args->rendezvous = opts.rendezvous;
|
||||
exec_args->stats_collector = opts.stats_collector;
|
||||
exec_args->call_frame = frame;
|
||||
exec_args->cancellation_manager = opts.cancellation_manager;
|
||||
exec_args->step_container = opts.step_container;
|
||||
exec_args->runner = *opts.runner;
|
||||
|
||||
if (opts.remote_execution) {
|
||||
RunRemote(opts, handle, args, rets, exec_args, item, done);
|
||||
return;
|
||||
}
|
||||
|
||||
s = frame->SetArgs(args);
|
||||
if (!s.ok()) {
|
||||
delete frame;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
item->exec->RunAsync(
|
||||
// Executor args
|
||||
exec_args,
|
||||
*exec_args,
|
||||
// Done callback.
|
||||
[item, frame, rets, done](const Status& status) {
|
||||
[item, frame, rets, done, exec_args](const Status& status) {
|
||||
item->Unref();
|
||||
Status s = status;
|
||||
if (s.ok()) {
|
||||
s = frame->GetRetvals(rets);
|
||||
s = frame->ConsumeRetvals(rets);
|
||||
}
|
||||
delete frame;
|
||||
delete exec_args;
|
||||
done(s);
|
||||
});
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
|
||||
FunctionLibraryRuntime::Options opts,
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||
std::atomic<int32> call_count(0);
|
||||
std::function<void(std::function<void()>)> runner =
|
||||
@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
Notification done;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.runner = &runner;
|
||||
std::vector<Tensor> out;
|
||||
Status status;
|
||||
@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
return Run(flr, handle, args, std::move(rets));
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
return Run(flr, handle, opts, args, std::move(rets));
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
|
||||
@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
|
||||
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||
opts.source_device = "/device:CPU:1";
|
||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||
TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
|
||||
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||
TensorShape({})));
|
||||
TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
|
||||
opts.remote_execution = true;
|
||||
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
|
||||
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||
TensorShape({})));
|
||||
opts.rendezvous->Unref();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
/* static */
|
||||
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
||||
const AttrSlice& attrs) {
|
||||
const AttrValue* value;
|
||||
@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
||||
return value->s();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status ProcessFunctionLibraryRuntime::SendTensors(
|
||||
const string& source_device, const string& target_device,
|
||||
const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
|
||||
const Rendezvous::Args& args, Rendezvous* rendezvous) {
|
||||
std::vector<string> keys;
|
||||
for (int i = 0; i < tensors_to_send.size(); ++i) {
|
||||
string name = strings::StrCat(key_prefix, i);
|
||||
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
|
||||
FrameAndIter(0, 0));
|
||||
keys.push_back(key);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */
|
||||
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||
const string& source_device, const string& target_device,
|
||||
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
|
||||
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
|
||||
const StatusCallback& done) {
|
||||
std::vector<string> keys;
|
||||
for (int64 i = 0; i < num_tensors; ++i) {
|
||||
string name = strings::StrCat(key_prefix, i);
|
||||
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
|
||||
FrameAndIter(0, 0));
|
||||
keys.push_back(key);
|
||||
}
|
||||
RecvOutputsFromRendezvousAsync(
|
||||
rendezvous, args, keys, received_tensors,
|
||||
[done](const Status& status) { done(status); });
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||
const string& device_name, DeviceContext** device_context) {
|
||||
*device_context = nullptr;
|
||||
FunctionLibraryRuntime* flr = GetFLR(device_name);
|
||||
if (flr == nullptr) {
|
||||
return errors::InvalidArgument("Device name: ", device_name, " not found.");
|
||||
}
|
||||
Device* device = flr->device();
|
||||
string device_type = device->parsed_name().type;
|
||||
if (device_type == "CPU") return Status::OK();
|
||||
if (device_type == "GPU") {
|
||||
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
|
||||
if (dev_info) {
|
||||
*device_context = dev_info->default_context;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::Internal("Device type: ", device_type,
|
||||
" is currently unsupported for remote ",
|
||||
"function executions");
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||
const string& device_name) {
|
||||
if (flr_map_.find(device_name) == flr_map_.end()) {
|
||||
@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
|
||||
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
||||
const string& device_name, FunctionLibraryRuntime::Handle handle) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK_LE(handle, function_data_.size());
|
||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||
function_data_[handle];
|
||||
if (p.first != device_name) {
|
||||
@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
||||
return p.second;
|
||||
}
|
||||
|
||||
string ProcessFunctionLibraryRuntime::GetDeviceName(
|
||||
FunctionLibraryRuntime::Handle handle) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK_LE(handle, function_data_.size());
|
||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||
function_data_[handle];
|
||||
return p.first;
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::Instantiate(
|
||||
const string& function_name, AttrSlice attrs,
|
||||
FunctionLibraryRuntime::Handle* handle) {
|
||||
@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
||||
if (!opts.remote_execution) {
|
||||
done(errors::InvalidArgument(
|
||||
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
|
||||
"is a remote execution."));
|
||||
return;
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* flr = nullptr;
|
||||
string target_device;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
CHECK_LE(handle, function_data_.size());
|
||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||
function_data_[handle];
|
||||
target_device = p.first;
|
||||
flr = GetFLR(p.first);
|
||||
}
|
||||
if (flr != nullptr) {
|
||||
return flr->Run(opts, handle, args, rets, std::move(done));
|
||||
auto rendezvous = opts.rendezvous;
|
||||
string source_device = opts.source_device;
|
||||
Rendezvous::Args rendez_args;
|
||||
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
// Send the args over to the target device.
|
||||
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
|
||||
rendezvous);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
|
||||
flr->Run(opts, handle, args, remote_rets,
|
||||
[source_device, target_device, rendezvous, remote_rets, rets, done,
|
||||
rendez_args](const Status& status) {
|
||||
if (!status.ok()) {
|
||||
delete remote_rets;
|
||||
done(status);
|
||||
return;
|
||||
}
|
||||
int64 num_returns = remote_rets->size();
|
||||
delete remote_rets;
|
||||
// Now receive the return values from the target.
|
||||
ReceiveTensorsAsync(target_device, source_device, "ret_",
|
||||
num_returns, rendez_args, rendezvous, rets,
|
||||
done);
|
||||
});
|
||||
} else {
|
||||
done(errors::Internal("Could not find device"));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
|
||||
// attribute, returns "". Canonicalizes the device name.
|
||||
static string ObtainFunctionTarget(const AttrSlice& attrs);
|
||||
|
||||
// Sends `tensors_to_send` from `source_device` to `target_device` using
|
||||
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
|
||||
// Rendezvous. Method takes references on each of the `tensors_to_send`.
|
||||
// Method doesn't block.
|
||||
static Status SendTensors(const string& source_device,
|
||||
const string& target_device,
|
||||
const string& key_prefix,
|
||||
gtl::ArraySlice<Tensor> tensors_to_send,
|
||||
const Rendezvous::Args& args,
|
||||
Rendezvous* rendezvous);
|
||||
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
// Receives `received_tensors` from `target_device` (originally sent from
|
||||
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
|
||||
// keys to be retrieved. Method doesn't block and calls `done` when
|
||||
// `num_tensors` are fetched.
|
||||
static void ReceiveTensorsAsync(const string& source_device,
|
||||
const string& target_device,
|
||||
const string& key_prefix, int64 num_tensors,
|
||||
const Rendezvous::Args& args,
|
||||
Rendezvous* rendezvous,
|
||||
std::vector<Tensor>* received_tensors,
|
||||
const StatusCallback& done);
|
||||
|
||||
static const char kDefaultFLRDevice[];
|
||||
// Returns the FunctionLibraryRuntime for the corresponding device_name.
|
||||
FunctionLibraryRuntime* GetFLR(const string& device_name);
|
||||
@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime {
|
||||
FunctionLibraryRuntime::DoneCallback done);
|
||||
|
||||
private:
|
||||
// For a given device_name, returns a DeviceContext for copying
|
||||
// tensors to/from the device.
|
||||
Status GetDeviceContext(const string& device_name,
|
||||
DeviceContext** device_context);
|
||||
|
||||
// Looks up the information for the given `handle` and returns the name
|
||||
// of the device where the function is registered.
|
||||
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
|
||||
|
||||
friend class FunctionLibraryRuntimeImpl;
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// Holds all the function invocations here.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts));
|
||||
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||
}
|
||||
|
||||
Status Run(const string& name, test::function::Attrs attrs,
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
|
||||
test::function::Attrs attrs, const std::vector<Tensor>& args,
|
||||
std::vector<Tensor*> rets) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = proc_flr_->Instantiate(name, attrs, &handle);
|
||||
if (!status.ok()) {
|
||||
@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
Notification done;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.runner = &runner;
|
||||
std::vector<Tensor> out;
|
||||
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
|
||||
@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||
IntraProcessRendezvous* rendezvous_;
|
||||
};
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
||||
@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
||||
EXPECT_EQ(flr->device(), devices_[1]);
|
||||
flr = proc_flr_->GetFLR("abc");
|
||||
EXPECT_EQ(flr, nullptr);
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
||||
@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
||||
Init({test::function::XTimesTwo()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.remote_execution = true;
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
Tensor y;
|
||||
TF_CHECK_OK(
|
||||
Run("XTimesTwo",
|
||||
Run("XTimesTwo", opts,
|
||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||
{&y}));
|
||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
||||
{}, {&y}));
|
||||
TF_CHECK_OK(Run("FindDevice", opts,
|
||||
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||
TensorShape({})));
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
TF_CHECK_OK(
|
||||
Run("XTimesTwo",
|
||||
Run("XTimesTwo", opts,
|
||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||
{&y}));
|
||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||
TF_CHECK_OK(
|
||||
Run("XTimesFour",
|
||||
Run("XTimesFour", opts,
|
||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||
{&y}));
|
||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
||||
{}, {&y}));
|
||||
TF_CHECK_OK(Run("FindDevice", opts,
|
||||
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||
TensorShape({})));
|
||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
||||
{}, {&y}));
|
||||
TF_CHECK_OK(Run("FindDevice", opts,
|
||||
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||
TensorShape({})));
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
||||
{}, {&y}));
|
||||
TF_CHECK_OK(Run("FindDevice", opts,
|
||||
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||
TensorShape({})));
|
||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
||||
{}, {&y}));
|
||||
TF_CHECK_OK(Run("FindDevice", opts,
|
||||
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||
test::ExpectTensorEqual<string>(
|
||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||
TensorShape({})));
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
|
||||
const Rendezvous::Args& args,
|
||||
const std::vector<string>& keys,
|
||||
gtl::ArraySlice<Tensor> tensors_to_send) {
|
||||
if (keys.size() != tensors_to_send.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"keys and tensors_to_send are not the same size. keys.size() = ",
|
||||
keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (int i = 0; i < keys.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
|
||||
TF_RETURN_IF_ERROR(
|
||||
rendezvous->Send(parsed, args, tensors_to_send[i], false));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||
const Rendezvous::Args& args,
|
||||
const std::vector<string>& keys,
|
||||
std::vector<Tensor>* received_tensors,
|
||||
const StatusCallback& done) {
|
||||
if (keys.empty()) {
|
||||
done(Status::OK());
|
||||
return;
|
||||
}
|
||||
received_tensors->reserve(keys.size());
|
||||
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
|
||||
for (int i = 0; i < keys.size(); ++i) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(keys[i], &parsed);
|
||||
received_tensors->push_back(Tensor());
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
arguments.push_back(
|
||||
std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
mutex mu;
|
||||
int64 done_counter;
|
||||
Status shared_status = Status::OK();
|
||||
} CallState;
|
||||
CallState* call_state = new CallState;
|
||||
call_state->done_counter = keys.size();
|
||||
for (auto& p : arguments) {
|
||||
const string& key = std::get<0>(p);
|
||||
Tensor* val = std::get<1>(p);
|
||||
Rendezvous::ParsedKey parsed = std::get<2>(p);
|
||||
rendezvous->RecvAsync(
|
||||
parsed, args,
|
||||
[val, done, key, call_state](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& v, const bool is_dead) {
|
||||
Status status = s;
|
||||
if (status.ok()) {
|
||||
*val = v;
|
||||
if (is_dead) {
|
||||
status = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
}
|
||||
}
|
||||
call_state->mu.lock();
|
||||
call_state->shared_status.Update(status);
|
||||
call_state->done_counter--;
|
||||
// If we are the last async call to return, call the done callback.
|
||||
if (call_state->done_counter == 0) {
|
||||
const Status& final_status = call_state->shared_status;
|
||||
call_state->mu.unlock();
|
||||
done(final_status);
|
||||
delete call_state;
|
||||
return;
|
||||
}
|
||||
call_state->mu.unlock();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||
const Rendezvous::Args& args) {
|
||||
// Receives values requested by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
Tensor* val = &p.second;
|
||||
bool is_dead = false;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
|
||||
if (is_dead) {
|
||||
return errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef std::map<string, Tensor> NamedTensors;
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
// Uses `rendezvous` to send tensors in `in`.
|
||||
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
|
||||
const Rendezvous::Args& args,
|
||||
const std::vector<string>& keys,
|
||||
gtl::ArraySlice<Tensor> tensors_to_send);
|
||||
|
||||
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||
const Rendezvous::Args& args,
|
||||
const std::vector<string>& keys,
|
||||
std::vector<Tensor>* received_tensors,
|
||||
const StatusCallback& done);
|
||||
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||
const Rendezvous::Args& args);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
@ -0,0 +1,94 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class RendezvousUtilTest : public ::testing::Test {
|
||||
public:
|
||||
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
|
||||
|
||||
~RendezvousUtilTest() override { rendez_->Unref(); }
|
||||
|
||||
Rendezvous* rendez_;
|
||||
};
|
||||
|
||||
// string -> Tensor<string>
|
||||
Tensor V(const string& content) {
|
||||
Tensor tensor(DT_STRING, TensorShape({}));
|
||||
tensor.scalar<string>()() = content;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Tensor<string> -> string
|
||||
string V(const Tensor& tensor) {
|
||||
CHECK_EQ(tensor.dtype(), DT_STRING);
|
||||
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
string MakeStringKey(const string& name) {
|
||||
return Rendezvous::CreateKey(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", 0,
|
||||
"/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
|
||||
}
|
||||
|
||||
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
|
||||
// Fire off sends before receive the tensors.
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(SendTensorsToRendezvous(
|
||||
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||
{V("hello1"), V("hello2")}));
|
||||
|
||||
Notification n;
|
||||
std::vector<Tensor> received_keys;
|
||||
RecvOutputsFromRendezvousAsync(
|
||||
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||
&received_keys, [&n](const Status& status) { n.Notify(); });
|
||||
n.WaitForNotification();
|
||||
|
||||
EXPECT_EQ(2, received_keys.size());
|
||||
EXPECT_EQ("hello1", V(received_keys[0]));
|
||||
EXPECT_EQ("hello2", V(received_keys[1]));
|
||||
}
|
||||
|
||||
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
|
||||
// Fire off recvs, wait for a notification in the callback.
|
||||
Rendezvous::Args args;
|
||||
|
||||
Notification n;
|
||||
std::vector<Tensor> received_keys;
|
||||
RecvOutputsFromRendezvousAsync(
|
||||
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||
&received_keys, [&n](const Status& status) { n.Notify(); });
|
||||
|
||||
TF_ASSERT_OK(SendTensorsToRendezvous(
|
||||
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||
{V("hello1"), V("hello2")}));
|
||||
|
||||
n.WaitForNotification();
|
||||
|
||||
EXPECT_EQ(2, received_keys.size());
|
||||
EXPECT_EQ("hello1", V(received_keys[0]));
|
||||
EXPECT_EQ("hello2", V(received_keys[1]));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
|
||||
const NamedTensors& in) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (const auto& p : in) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
|
||||
NamedTensors* out) {
|
||||
// Receives values requested by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
Tensor* val = &p.second;
|
||||
bool is_dead = false;
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
|
||||
}
|
||||
if (is_dead) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
}
|
||||
if (!s.ok()) return s;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||
NamedTensors* out,
|
||||
const StatusCallback& done) {
|
||||
if (out->empty()) {
|
||||
done(Status::OK());
|
||||
return;
|
||||
}
|
||||
// We compute the args before calling RecvAsync because we need to ensure that
|
||||
// out isn't being iterated over after done is called, since done deletes out.
|
||||
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
|
||||
for (auto& p : *out) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(p.first, &parsed);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
args.push_back(std::make_tuple(p.first, &p.second, parsed));
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
mutex mu;
|
||||
int done_counter;
|
||||
Status shared_status = Status::OK();
|
||||
} CallState;
|
||||
CallState* call_state = new CallState;
|
||||
call_state->done_counter = out->size();
|
||||
for (auto& p : args) {
|
||||
const string& key = std::get<0>(p);
|
||||
Tensor* val = std::get<1>(p);
|
||||
Rendezvous::ParsedKey parsed = std::get<2>(p);
|
||||
rendezvous->RecvAsync(
|
||||
parsed, Rendezvous::Args(),
|
||||
[val, done, key, call_state](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& v, const bool is_dead) {
|
||||
Status status = s;
|
||||
if (status.ok()) {
|
||||
*val = v;
|
||||
if (is_dead) {
|
||||
status = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
}
|
||||
}
|
||||
call_state->mu.lock();
|
||||
call_state->shared_status.Update(status);
|
||||
call_state->done_counter--;
|
||||
// If we are the last async call to return, call the done callback.
|
||||
if (call_state->done_counter == 0) {
|
||||
const Status& final_status = call_state->shared_status;
|
||||
call_state->mu.unlock();
|
||||
done(final_status);
|
||||
delete call_state;
|
||||
return;
|
||||
}
|
||||
call_state->mu.unlock();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
Status s = SendInputsToRendezvous(rendezvous, in);
|
||||
std::vector<string> keys;
|
||||
std::vector<Tensor> tensors_to_send;
|
||||
keys.reserve(in.size());
|
||||
tensors_to_send.reserve(in.size());
|
||||
for (const auto& p : in) {
|
||||
keys.push_back(p.first);
|
||||
tensors_to_send.push_back(p.second);
|
||||
}
|
||||
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||
tensors_to_send);
|
||||
rendezvous->Unref();
|
||||
return s;
|
||||
}
|
||||
|
||||
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
Status s = RecvOutputsFromRendezvous(rendezvous, out);
|
||||
Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
|
||||
rendezvous->Unref();
|
||||
return s;
|
||||
}
|
||||
@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
||||
StatusCallback done) {
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
RecvOutputsFromRendezvousAsync(rendezvous, out,
|
||||
[done, rendezvous](const Status s) {
|
||||
rendezvous->Unref();
|
||||
done(s);
|
||||
});
|
||||
std::vector<string> keys;
|
||||
std::vector<Tensor>* received_keys = new std::vector<Tensor>;
|
||||
keys.reserve(out->size());
|
||||
received_keys->reserve(out->size());
|
||||
for (const auto& p : *out) {
|
||||
keys.push_back(p.first);
|
||||
received_keys->push_back(p.second);
|
||||
}
|
||||
RecvOutputsFromRendezvousAsync(
|
||||
rendezvous, Rendezvous::Args(), keys, received_keys,
|
||||
[done, rendezvous, received_keys, out, keys](const Status s) {
|
||||
rendezvous->Unref();
|
||||
for (int i = 0; i < keys.size(); ++i) {
|
||||
(*out)[keys[i]] = (*received_keys)[i];
|
||||
}
|
||||
delete received_keys;
|
||||
done(s);
|
||||
});
|
||||
}
|
||||
|
||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
|
||||
// Sends values specified by the caller.
|
||||
if (s.ok()) {
|
||||
s = SendInputsToRendezvous(rendezvous, in);
|
||||
std::vector<string> keys;
|
||||
std::vector<Tensor> tensors_to_send;
|
||||
keys.reserve(in.size());
|
||||
tensors_to_send.reserve(in.size());
|
||||
for (auto& p : in) {
|
||||
keys.push_back(p.first);
|
||||
tensors_to_send.push_back(p.second);
|
||||
}
|
||||
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||
tensors_to_send);
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
|
@ -169,11 +169,6 @@ class GraphMgr {
|
||||
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph);
|
||||
|
||||
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
|
||||
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
|
||||
const StatusCallback& done);
|
||||
|
||||
Status InitItem(const string& session, const GraphDef& gdef,
|
||||
const GraphOptions& graph_options,
|
||||
const DebugOptions& debug_options, Item* item);
|
||||
|
@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
|
||||
StepStatsCollector* stats_collector = nullptr;
|
||||
|
||||
std::function<void(std::function<void()>)>* runner = nullptr;
|
||||
|
||||
// Parameters for remote function execution.
|
||||
bool remote_execution = false;
|
||||
string source_device = ""; // Fully specified device name.
|
||||
};
|
||||
typedef std::function<void(const Status&)> DoneCallback;
|
||||
virtual void Run(const Options& opts, Handle handle,
|
||||
|
@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
|
||||
AttrValueMap attr_values = func_->attr();
|
||||
AttrValue v;
|
||||
v.set_s(target->scalar<string>()());
|
||||
const string& target_device = target->scalar<string>()();
|
||||
v.set_s(target_device);
|
||||
AddAttr("_target", v, &attr_values);
|
||||
|
||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.runner = ctx->runner();
|
||||
opts.source_device = lib->device()->name();
|
||||
if (opts.source_device != target_device) {
|
||||
opts.remote_execution = true;
|
||||
}
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
std::vector<Tensor> args;
|
||||
args.reserve(arguments.size());
|
||||
for (const Tensor& argument : arguments) {
|
||||
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
} // namespace tensorflow
|
||||
|
@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase):
|
||||
mul = sess.run(remote_op)
|
||||
self.assertEqual(mul, [6])
|
||||
|
||||
def testRemoteFunctionCPUGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def _remote_fn(a, b):
|
||||
return math_ops.multiply(a, b)
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||
a = variables.Variable(2, dtype=dtypes.float32)
|
||||
b = variables.Variable(3, dtype=dtypes.float32)
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||
remote_op = functional_ops.remote_call(
|
||||
args=[a, b],
|
||||
Tout=[dtypes.float32],
|
||||
f=_remote_fn,
|
||||
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
mul = sess.run(remote_op)
|
||||
self.assertEqual(mul, 9.0)
|
||||
|
||||
def testRemoteFunctionGPUCPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def _remote_fn(a, b):
|
||||
return math_ops.multiply(a, b)
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||
a = variables.Variable(2, dtype=dtypes.float32)
|
||||
b = variables.Variable(3, dtype=dtypes.float32)
|
||||
|
||||
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||
remote_op = functional_ops.remote_call(
|
||||
args=[a, b],
|
||||
Tout=[dtypes.float32],
|
||||
f=_remote_fn,
|
||||
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
mul = sess.run(remote_op)
|
||||
self.assertEqual(mul, 9.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user