Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.
PiperOrigin-RevId: 168038285
This commit is contained in:
parent
82cc6529f4
commit
450c3b5626
tensorflow
contrib/data/python/kernel_tests
core
BUILD
common_runtime
function.ccfunction_test.ccprocess_function_library_runtime.ccprocess_function_library_runtime.hprocess_function_library_runtime_test.ccrendezvous_util.ccrendezvous_util.hrendezvous_util_test.cc
distributed_runtime
framework
kernels
python/kernel_tests
@ -24,6 +24,8 @@ py_test(
|
|||||||
"//tensorflow/python:functional_ops",
|
"//tensorflow/python:functional_ops",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:parsing_ops",
|
||||||
|
"//tensorflow/python:script_ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
|
@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import parsing_ops
|
||||||
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
|
|
||||||
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
|
|||||||
|
|
||||||
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
||||||
worker_config = config_pb2.ConfigProto()
|
worker_config = config_pb2.ConfigProto()
|
||||||
worker_config.device_count["CPU"] = 2
|
worker_config.device_count["CPU"] = 3
|
||||||
|
|
||||||
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
|
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||||
})
|
})
|
||||||
self.assertEqual(elem, [1])
|
self.assertEqual(elem, [1])
|
||||||
# Fails when target is cpu:0 where the resource is not located.
|
# Fails when target is cpu:2 where the resource is not located.
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(
|
sess.run(
|
||||||
remote_op,
|
remote_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
|
||||||
})
|
})
|
||||||
elem = sess.run(
|
elem = sess.run(
|
||||||
remote_op,
|
remote_op,
|
||||||
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
  """Reads a cpu:0 iterator through a `remote_call` op built under GPU:0.

  The one-shot iterator over [1, 2, 3] is created under a cpu:0 device scope,
  the `remote_call` op is built under a GPU:0 device scope, and every run
  feeds cpu:0 as the call target. The test is skipped when no GPU is
  available.
  """
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  cpu_device = "/job:localhost/replica:0/task:0/cpu:0"

  with ops.device(cpu_device):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    iterator_3 = dataset_3.make_one_shot_iterator()
    iterator_3_handle = iterator_3.string_handle()

  def _encode_raw(byte_array):
    # Rebuilds a string from the uint8 values, one character per value.
    return "".join(chr(item) for item in byte_array)

  @function.Defun(dtypes.uint8)
  def _remote_fn(h):
    # The handle arrives as uint8 values and is turned back into a string
    # before the iterator is looked up.
    # NOTE(review): presumably the uint8 round trip exists because the caller
    # is placed on the GPU -- confirm.
    handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
    remote_iterator = dataset_ops.Iterator.from_string_handle(
        handle, dataset_3.output_types, dataset_3.output_shapes)
    return remote_iterator.get_next()

  with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
    target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    iterator_3_handle_uint8 = parsing_ops.decode_raw(
        bytes=iterator_3_handle, out_type=dtypes.uint8)
    remote_op = functional_ops.remote_call(
        args=[iterator_3_handle_uint8],
        Tout=[dtypes.int32],
        f=_remote_fn,
        target=target_placeholder)

  feed = {target_placeholder: cpu_device}
  with self.test_session() as sess:
    # Each run pulls the next element of the dataset, in order.
    for expected in ([1], [2], [3]):
      elem = sess.run(remote_op, feed_dict=feed)
      self.assertEqual(elem, expected)
    # A fourth run finds the iterator exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(remote_op, feed_dict=feed)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -1785,6 +1785,7 @@ tf_cuda_library(
|
|||||||
"common_runtime/process_util.cc",
|
"common_runtime/process_util.cc",
|
||||||
"common_runtime/renamed_device.cc",
|
"common_runtime/renamed_device.cc",
|
||||||
"common_runtime/rendezvous_mgr.cc",
|
"common_runtime/rendezvous_mgr.cc",
|
||||||
|
"common_runtime/rendezvous_util.cc",
|
||||||
"common_runtime/resource_variable_read_optimizer.cc",
|
"common_runtime/resource_variable_read_optimizer.cc",
|
||||||
"common_runtime/session.cc",
|
"common_runtime/session.cc",
|
||||||
"common_runtime/session_factory.cc",
|
"common_runtime/session_factory.cc",
|
||||||
@ -1828,6 +1829,7 @@ tf_cuda_library(
|
|||||||
"common_runtime/profile_handler.h",
|
"common_runtime/profile_handler.h",
|
||||||
"common_runtime/renamed_device.h",
|
"common_runtime/renamed_device.h",
|
||||||
"common_runtime/rendezvous_mgr.h",
|
"common_runtime/rendezvous_mgr.h",
|
||||||
|
"common_runtime/rendezvous_util.h",
|
||||||
"common_runtime/session_factory.h",
|
"common_runtime/session_factory.h",
|
||||||
"common_runtime/graph_execution_state.h",
|
"common_runtime/graph_execution_state.h",
|
||||||
"common_runtime/placer.h",
|
"common_runtime/placer.h",
|
||||||
@ -2669,29 +2671,29 @@ tf_cc_test(
|
|||||||
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
|
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
|
||||||
linkstatic = tf_kernel_tests_linkstatic(),
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
|
||||||
":core_cpu",
|
":core_cpu",
|
||||||
":core_cpu_internal",
|
":core_cpu_internal",
|
||||||
":direct_session_internal",
|
|
||||||
":framework",
|
":framework",
|
||||||
":framework_internal",
|
|
||||||
":lib",
|
|
||||||
":lib_internal",
|
|
||||||
":ops",
|
|
||||||
":protos_all_cc",
|
|
||||||
":test",
|
":test",
|
||||||
":test_main",
|
":test_main",
|
||||||
":testlib",
|
":testlib",
|
||||||
"//tensorflow/cc:cc_ops",
|
|
||||||
"//tensorflow/cc:cc_ops_internal",
|
|
||||||
"//tensorflow/cc:function_ops",
|
"//tensorflow/cc:function_ops",
|
||||||
"//tensorflow/cc:functional_ops",
|
|
||||||
"//tensorflow/core/kernels:cast_op",
|
"//tensorflow/core/kernels:cast_op",
|
||||||
"//tensorflow/core/kernels:cwise_op",
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
"//tensorflow/core/kernels:function_ops",
|
"//tensorflow/core/kernels:function_ops",
|
||||||
"//tensorflow/core/kernels:matmul_op",
|
],
|
||||||
"//tensorflow/core/kernels:shape_ops",
|
)
|
||||||
"//third_party/eigen3",
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "common_runtime_rendezvous_util_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["common_runtime/rendezvous_util_test.cc"],
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
deps = [
|
||||||
|
":core_cpu_internal",
|
||||||
|
":lib",
|
||||||
|
":test",
|
||||||
|
":test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
|||||||
FunctionBody** g_body);
|
FunctionBody** g_body);
|
||||||
bool IsLocalTarget(const AttrSlice& attrs);
|
bool IsLocalTarget(const AttrSlice& attrs);
|
||||||
AttrValueMap FixAttrs(const AttrSlice& attrs);
|
AttrValueMap FixAttrs(const AttrSlice& attrs);
|
||||||
|
void RunRemote(const Options& opts, Handle handle,
|
||||||
|
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
|
||||||
|
Executor::Args* exec_args, Item* item, DoneCallback done);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
|
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
|
||||||
};
|
};
|
||||||
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
|
||||||
|
gtl::ArraySlice<Tensor> args,
|
||||||
|
std::vector<Tensor>* rets,
|
||||||
|
Executor::Args* exec_args,
|
||||||
|
Item* item, DoneCallback done) {
|
||||||
|
FunctionCallFrame* frame = exec_args->call_frame;
|
||||||
|
string target_device = parent_->GetDeviceName(handle);
|
||||||
|
string source_device = opts.source_device;
|
||||||
|
Rendezvous* rendezvous = opts.rendezvous;
|
||||||
|
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
|
||||||
|
Rendezvous::Args rendez_args;
|
||||||
|
Status s =
|
||||||
|
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ProcFLR sends the arguments to the function from the source_device to
|
||||||
|
// the target_device. So here we receive those arguments. Similarly, when the
|
||||||
|
// computation is done and stored in *rets, we send the return values back
|
||||||
|
// to the source_device (caller) so that the ProcFLR can receive them later.
|
||||||
|
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
|
||||||
|
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||||
|
source_device, target_device, "arg_", args.size(), rendez_args,
|
||||||
|
rendezvous, remote_args,
|
||||||
|
[frame, remote_args, item, source_device, target_device, rendezvous,
|
||||||
|
rendez_args, rets, done, exec_args](const Status& status) {
|
||||||
|
Status s = status;
|
||||||
|
s = frame->SetArgs(*remote_args);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
item->exec->RunAsync(
|
||||||
|
*exec_args,
|
||||||
|
[item, frame, rets, done, source_device, target_device, rendezvous,
|
||||||
|
rendez_args, remote_args, exec_args](const Status& status) {
|
||||||
|
item->Unref();
|
||||||
|
Status s = status;
|
||||||
|
if (s.ok()) {
|
||||||
|
s = frame->ConsumeRetvals(rets);
|
||||||
|
}
|
||||||
|
delete frame;
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
s = ProcessFunctionLibraryRuntime::SendTensors(
|
||||||
|
target_device, source_device, "ret_", *rets, rendez_args,
|
||||||
|
rendezvous);
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||||
gtl::ArraySlice<Tensor> args,
|
gtl::ArraySlice<Tensor> args,
|
||||||
std::vector<Tensor>* rets,
|
std::vector<Tensor>* rets,
|
||||||
DoneCallback done) {
|
DoneCallback done) {
|
||||||
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
|
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
|
||||||
return done(errors::Cancelled(""));
|
done(errors::Cancelled(""));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
||||||
return parent_->Run(opts, handle, args, rets, done);
|
parent_->Run(opts, handle, args, rets, done);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||||
FunctionCallFrame* frame =
|
FunctionCallFrame* frame =
|
||||||
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
|
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
|
||||||
Status s = frame->SetArgs(args);
|
|
||||||
if (!s.ok()) {
|
|
||||||
delete frame;
|
|
||||||
return done(s);
|
|
||||||
}
|
|
||||||
Item* item = nullptr;
|
Item* item = nullptr;
|
||||||
s = GetOrCreateItem(handle, &item);
|
Status s = GetOrCreateItem(handle, &item);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
delete frame;
|
delete frame;
|
||||||
return done(s);
|
done(s);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
DCHECK(opts.runner != nullptr);
|
DCHECK(opts.runner != nullptr);
|
||||||
|
|
||||||
Executor::Args exec_args;
|
Executor::Args* exec_args = new Executor::Args;
|
||||||
// Inherit the step_id from the caller.
|
// Inherit the step_id from the caller.
|
||||||
exec_args.step_id = opts.step_id;
|
exec_args->step_id = opts.step_id;
|
||||||
exec_args.rendezvous = opts.rendezvous;
|
exec_args->rendezvous = opts.rendezvous;
|
||||||
exec_args.stats_collector = opts.stats_collector;
|
exec_args->stats_collector = opts.stats_collector;
|
||||||
exec_args.call_frame = frame;
|
exec_args->call_frame = frame;
|
||||||
exec_args.cancellation_manager = opts.cancellation_manager;
|
exec_args->cancellation_manager = opts.cancellation_manager;
|
||||||
exec_args.step_container = opts.step_container;
|
exec_args->step_container = opts.step_container;
|
||||||
exec_args.runner = *opts.runner;
|
exec_args->runner = *opts.runner;
|
||||||
|
|
||||||
|
if (opts.remote_execution) {
|
||||||
|
RunRemote(opts, handle, args, rets, exec_args, item, done);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
s = frame->SetArgs(args);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
item->exec->RunAsync(
|
item->exec->RunAsync(
|
||||||
// Executor args
|
// Executor args
|
||||||
exec_args,
|
*exec_args,
|
||||||
// Done callback.
|
// Done callback.
|
||||||
[item, frame, rets, done](const Status& status) {
|
[item, frame, rets, done, exec_args](const Status& status) {
|
||||||
item->Unref();
|
item->Unref();
|
||||||
Status s = status;
|
Status s = status;
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = frame->GetRetvals(rets);
|
s = frame->ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
delete frame;
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
done(s);
|
done(s);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/executor.h"
|
#include "tensorflow/core/common_runtime/executor.h"
|
||||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
|
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
|
||||||
|
FunctionLibraryRuntime::Options opts,
|
||||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||||
std::atomic<int32> call_count(0);
|
std::atomic<int32> call_count(0);
|
||||||
std::function<void(std::function<void()>)> runner =
|
std::function<void(std::function<void()>)> runner =
|
||||||
@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Notification done;
|
Notification done;
|
||||||
FunctionLibraryRuntime::Options opts;
|
|
||||||
opts.runner = &runner;
|
opts.runner = &runner;
|
||||||
std::vector<Tensor> out;
|
std::vector<Tensor> out;
|
||||||
Status status;
|
Status status;
|
||||||
@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
return Run(flr, handle, args, std::move(rets));
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
return Run(flr, handle, opts, args, std::move(rets));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
|
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
|
||||||
@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
|||||||
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
|
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
|
||||||
|
|
||||||
Tensor y;
|
Tensor y;
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||||
|
opts.source_device = "/device:CPU:1";
|
||||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||||
TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
|
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
|
opts.remote_execution = true;
|
||||||
|
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
|
||||||
|
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
opts.rendezvous->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
||||||
const AttrSlice& attrs) {
|
const AttrSlice& attrs) {
|
||||||
const AttrValue* value;
|
const AttrValue* value;
|
||||||
@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
|||||||
return value->s();
|
return value->s();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
// Sends `tensors_to_send` through `rendezvous`, one key per tensor.
//
// The i-th tensor is published under a key created by Rendezvous::CreateKey
// from `source_device`, `target_device` and the name `<key_prefix><i>`.
// Returns the status reported by SendTensorsToRendezvous.
Status ProcessFunctionLibraryRuntime::SendTensors(
    const string& source_device, const string& target_device,
    const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
    const Rendezvous::Args& args, Rendezvous* rendezvous) {
  std::vector<string> rendezvous_keys;
  rendezvous_keys.reserve(tensors_to_send.size());
  for (int idx = 0; idx < tensors_to_send.size(); ++idx) {
    // NOTE(review): the tensor index is also passed as CreateKey's second
    // argument -- confirm this is intended.
    rendezvous_keys.push_back(Rendezvous::CreateKey(
        source_device, idx, target_device, strings::StrCat(key_prefix, idx),
        FrameAndIter(0, 0)));
  }
  return SendTensorsToRendezvous(rendezvous, args, rendezvous_keys,
                                 tensors_to_send);
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||||
|
const string& source_device, const string& target_device,
|
||||||
|
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done) {
|
||||||
|
std::vector<string> keys;
|
||||||
|
for (int64 i = 0; i < num_tensors; ++i) {
|
||||||
|
string name = strings::StrCat(key_prefix, i);
|
||||||
|
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
|
||||||
|
FrameAndIter(0, 0));
|
||||||
|
keys.push_back(key);
|
||||||
|
}
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendezvous, args, keys, received_tensors,
|
||||||
|
[done](const Status& status) { done(status); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Looks up the DeviceContext to use for `device_name`.
//
// `*device_context` is always reset to nullptr first. For a "CPU" device it
// stays nullptr and OK is returned. For a "GPU" device with GPU device info
// it is set to that info's default context. Returns InvalidArgument when no
// FunctionLibraryRuntime exists for `device_name`, and Internal for any other
// device type or for a GPU device without device info.
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
    const string& device_name, DeviceContext** device_context) {
  *device_context = nullptr;

  FunctionLibraryRuntime* runtime = GetFLR(device_name);
  if (runtime == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }

  Device* device = runtime->device();
  string device_type = device->parsed_name().type;
  if (device_type == "CPU") {
    return Status::OK();
  }
  if (device_type == "GPU") {
    auto* gpu_info = device->tensorflow_gpu_device_info();
    if (gpu_info != nullptr) {
      *device_context = gpu_info->default_context;
      return Status::OK();
    }
    // A GPU device without device info falls through to the error below.
  }
  return errors::Internal("Device type: ", device_type,
                          " is currently unsupported for remote ",
                          "function executions");
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||||
const string& device_name) {
|
const string& device_name) {
|
||||||
if (flr_map_.find(device_name) == flr_map_.end()) {
|
if (flr_map_.find(device_name) == flr_map_.end()) {
|
||||||
@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
|
|||||||
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
||||||
const string& device_name, FunctionLibraryRuntime::Handle handle) {
|
const string& device_name, FunctionLibraryRuntime::Handle handle) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
CHECK_LE(handle, function_data_.size());
|
||||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||||
function_data_[handle];
|
function_data_[handle];
|
||||||
if (p.first != device_name) {
|
if (p.first != device_name) {
|
||||||
@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
|||||||
return p.second;
|
return p.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the device name recorded in `function_data_` for `handle`.
//
// Holds `mu_` for the lookup. CHECK-fails when `handle` is not a valid index
// into `function_data_`.
string ProcessFunctionLibraryRuntime::GetDeviceName(
    FunctionLibraryRuntime::Handle handle) {
  mutex_lock l(mu_);
  // `handle` is used as an index below, so it must be strictly less than the
  // size; `handle == size()` would read one past the last entry.
  // NOTE(review): assumes function_data_ is an index-addressed sequence, as
  // the size check and operator[] use suggest -- confirm.
  CHECK_LT(handle, function_data_.size());
  return function_data_[handle].first;
}
|
||||||
|
|
||||||
Status ProcessFunctionLibraryRuntime::Instantiate(
|
Status ProcessFunctionLibraryRuntime::Instantiate(
|
||||||
const string& function_name, AttrSlice attrs,
|
const string& function_name, AttrSlice attrs,
|
||||||
FunctionLibraryRuntime::Handle* handle) {
|
FunctionLibraryRuntime::Handle* handle) {
|
||||||
@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
|
|||||||
const FunctionLibraryRuntime::Options& opts,
|
const FunctionLibraryRuntime::Options& opts,
|
||||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||||
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
||||||
|
if (!opts.remote_execution) {
|
||||||
|
done(errors::InvalidArgument(
|
||||||
|
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
|
||||||
|
"is a remote execution."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* flr = nullptr;
|
FunctionLibraryRuntime* flr = nullptr;
|
||||||
|
string target_device;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
CHECK_LE(handle, function_data_.size());
|
||||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||||
function_data_[handle];
|
function_data_[handle];
|
||||||
|
target_device = p.first;
|
||||||
flr = GetFLR(p.first);
|
flr = GetFLR(p.first);
|
||||||
}
|
}
|
||||||
if (flr != nullptr) {
|
if (flr != nullptr) {
|
||||||
return flr->Run(opts, handle, args, rets, std::move(done));
|
auto rendezvous = opts.rendezvous;
|
||||||
|
string source_device = opts.source_device;
|
||||||
|
Rendezvous::Args rendez_args;
|
||||||
|
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Send the args over to the target device.
|
||||||
|
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
|
||||||
|
rendezvous);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
|
||||||
|
flr->Run(opts, handle, args, remote_rets,
|
||||||
|
[source_device, target_device, rendezvous, remote_rets, rets, done,
|
||||||
|
rendez_args](const Status& status) {
|
||||||
|
if (!status.ok()) {
|
||||||
|
delete remote_rets;
|
||||||
|
done(status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int64 num_returns = remote_rets->size();
|
||||||
|
delete remote_rets;
|
||||||
|
// Now receive the return values from the target.
|
||||||
|
ReceiveTensorsAsync(target_device, source_device, "ret_",
|
||||||
|
num_returns, rendez_args, rendezvous, rets,
|
||||||
|
done);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
done(errors::Internal("Could not find device"));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
// attribute, returns "". Canonicalizes the device name.
|
// attribute, returns "". Canonicalizes the device name.
|
||||||
static string ObtainFunctionTarget(const AttrSlice& attrs);
|
static string ObtainFunctionTarget(const AttrSlice& attrs);
|
||||||
|
|
||||||
|
// Sends `tensors_to_send` from `source_device` to `target_device` using
|
||||||
|
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
|
||||||
|
// Rendezvous. Method takes references on each of the `tensors_to_send`.
|
||||||
|
// Method doesn't block.
|
||||||
|
static Status SendTensors(const string& source_device,
|
||||||
|
const string& target_device,
|
||||||
|
const string& key_prefix,
|
||||||
|
gtl::ArraySlice<Tensor> tensors_to_send,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous);
|
||||||
|
|
||||||
|
typedef std::function<void(const Status&)> StatusCallback;
|
||||||
|
|
||||||
|
// Receives `received_tensors` from `target_device` (originally sent from
|
||||||
|
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
|
||||||
|
// keys to be retrieved. Method doesn't block and calls `done` when
|
||||||
|
// `num_tensors` are fetched.
|
||||||
|
static void ReceiveTensorsAsync(const string& source_device,
|
||||||
|
const string& target_device,
|
||||||
|
const string& key_prefix, int64 num_tensors,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous,
|
||||||
|
std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done);
|
||||||
|
|
||||||
static const char kDefaultFLRDevice[];
|
static const char kDefaultFLRDevice[];
|
||||||
// Returns the FunctionLibraryRuntime for the corresponding device_name.
|
// Returns the FunctionLibraryRuntime for the corresponding device_name.
|
||||||
FunctionLibraryRuntime* GetFLR(const string& device_name);
|
FunctionLibraryRuntime* GetFLR(const string& device_name);
|
||||||
@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
FunctionLibraryRuntime::DoneCallback done);
|
FunctionLibraryRuntime::DoneCallback done);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// For a given device_name, returns a DeviceContext for copying
|
||||||
|
// tensors to/from the device.
|
||||||
|
Status GetDeviceContext(const string& device_name,
|
||||||
|
DeviceContext** device_context);
|
||||||
|
|
||||||
|
// Looks up the information for the given `handle` and returns the name
|
||||||
|
// of the device where the function is registered.
|
||||||
|
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
|
||||||
|
|
||||||
|
friend class FunctionLibraryRuntimeImpl;
|
||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
|
|
||||||
// Holds all the function invocations here.
|
// Holds all the function invocations here.
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||||
opts));
|
opts));
|
||||||
|
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Run(const string& name, test::function::Attrs attrs,
|
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
|
||||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
test::function::Attrs attrs, const std::vector<Tensor>& args,
|
||||||
|
std::vector<Tensor*> rets) {
|
||||||
FunctionLibraryRuntime::Handle handle;
|
FunctionLibraryRuntime::Handle handle;
|
||||||
Status status = proc_flr_->Instantiate(name, attrs, &handle);
|
Status status = proc_flr_->Instantiate(name, attrs, &handle);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Notification done;
|
Notification done;
|
||||||
FunctionLibraryRuntime::Options opts;
|
|
||||||
opts.runner = &runner;
|
opts.runner = &runner;
|
||||||
std::vector<Tensor> out;
|
std::vector<Tensor> out;
|
||||||
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
|
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
|
||||||
@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||||
|
IntraProcessRendezvous* rendezvous_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
||||||
@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
|||||||
EXPECT_EQ(flr->device(), devices_[1]);
|
EXPECT_EQ(flr->device(), devices_[1]);
|
||||||
flr = proc_flr_->GetFLR("abc");
|
flr = proc_flr_->GetFLR("abc");
|
||||||
EXPECT_EQ(flr, nullptr);
|
EXPECT_EQ(flr, nullptr);
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
||||||
@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
|||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
||||||
Init({test::function::XTimesTwo()});
|
Init({test::function::XTimesTwo()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesTwo",
|
Run("XTimesTwo", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||||
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
|
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
|
||||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesTwo",
|
Run("XTimesTwo", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesFour",
|
Run("XTimesFour", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Sends `tensors_to_send[i]` to `rendezvous` under `keys[i]`, for every i.
//
// Each entry of `keys` is an unparsed rendezvous key string; it is parsed with
// Rendezvous::ParseKey immediately before the corresponding send. `args` is
// forwarded unchanged to every Rendezvous::Send call, and every tensor is sent
// with is_dead == false.
//
// Returns InvalidArgument if `keys` and `tensors_to_send` differ in length
// (nothing is sent in that case). Otherwise returns the first parse or send
// error encountered -- tensors earlier in the list have already been sent at
// that point -- or OK once every tensor has been sent.
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
                               const Rendezvous::Args& args,
                               const std::vector<string>& keys,
                               gtl::ArraySlice<Tensor> tensors_to_send) {
  if (keys.size() != tensors_to_send.size()) {
    return errors::InvalidArgument(
        "keys and tensors_to_send are not the same size. keys.size() = ",
        keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
  }
  // Reused across iterations; ParseKey overwrites it for each key.
  Rendezvous::ParsedKey parsed;
  // Unsigned index so the comparison against keys.size() is not a
  // signed/unsigned mix.
  for (size_t i = 0; i < keys.size(); ++i) {
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
    TF_RETURN_IF_ERROR(
        rendezvous->Send(parsed, args, tensors_to_send[i], /*is_dead=*/false));
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Asynchronously receives one tensor per entry of `keys` from `rendezvous`.
//
// The tensor received for `keys[i]` is stored in `(*received_tensors)[i]`.
// `received_tensors` is grown to `keys.size()` elements if it is shorter;
// elements it already holds are kept and overwritten in place as results
// arrive. `args` is forwarded unchanged to every Rendezvous::RecvAsync call.
//
// `done` is invoked exactly once:
//   * immediately with OK if `keys` is empty;
//   * immediately with the parse error if any key fails Rendezvous::ParseKey
//     (no receive is started in that case);
//   * otherwise from the callback of whichever receive completes last, with
//     the per-key statuses folded together via Status::Update. A tensor
//     delivered with is_dead == true is reported as InvalidArgument.
//
// `received_tensors` must stay valid until `done` has been called.
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
                                    const Rendezvous::Args& args,
                                    const std::vector<string>& keys,
                                    std::vector<Tensor>* received_tensors,
                                    const StatusCallback& done) {
  if (keys.empty()) {
    done(Status::OK());
    return;
  }

  // Size the output once, before any element address is taken. The receive
  // callbacks write through raw Tensor* pointers, so the vector must not
  // reallocate after this point. Appending element by element here would be
  // unsafe when the caller hands in a vector that already holds tensors: the
  // reserved capacity would be exceeded and earlier pointers would dangle.
  if (received_tensors->size() < keys.size()) {
    received_tensors->resize(keys.size());
  }

  // Parse every key before starting any receive, so a malformed key is
  // reported through `done` without leaving receives in flight. The loop that
  // issues the receives below works only on this local copy.
  std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
  arguments.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    Rendezvous::ParsedKey parsed;
    Status s = Rendezvous::ParseKey(keys[i], &parsed);
    if (!s.ok()) {
      done(s);
      return;
    }
    arguments.push_back(
        std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
  }

  // Bookkeeping shared by all per-key callbacks. It is heap-allocated because
  // the callbacks may outlive this function; the callback that brings
  // `done_counter` to zero deletes it.
  struct CallState {
    mutex mu;
    int64 done_counter;
    Status shared_status = Status::OK();
  };
  CallState* call_state = new CallState;
  call_state->done_counter = keys.size();
  for (auto& p : arguments) {
    const string& key = std::get<0>(p);
    Tensor* val = std::get<1>(p);
    Rendezvous::ParsedKey parsed = std::get<2>(p);
    rendezvous->RecvAsync(
        parsed, args,
        // `key` and `done` are captured by value: the callback may run after
        // this function (and the caller's `keys`) are gone.
        [val, done, key, call_state](const Status& s,
                                     const Rendezvous::Args& send_args,
                                     const Rendezvous::Args& recv_args,
                                     const Tensor& v, const bool is_dead) {
          Status status = s;
          if (status.ok()) {
            *val = v;
            if (is_dead) {
              status = errors::InvalidArgument("The tensor returned for ", key,
                                               " was not valid.");
            }
          }
          call_state->mu.lock();
          call_state->shared_status.Update(status);
          call_state->done_counter--;
          // If we are the last async call to return, call the done callback.
          if (call_state->done_counter == 0) {
            const Status& final_status = call_state->shared_status;
            // Release the lock before invoking user code; no other callback
            // can touch `call_state` any more since the counter hit zero.
            call_state->mu.unlock();
            done(final_status);
            delete call_state;
            return;
          }
          call_state->mu.unlock();
        });
  }
}
|
||||||
|
|
||||||
|
// Receives, via Rendezvous::Recv, the tensor for every key present in `*out`
// and stores it as that key's mapped value.
//
// The map keys are unparsed rendezvous key strings; `args` is forwarded
// unchanged to each Recv call. Entries are processed in map iteration order
// and processing stops at the first failure. Returns the parse or receive
// error, InvalidArgument if a tensor is delivered with is_dead set, or OK
// once every entry has been filled in.
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
                                 const Rendezvous::Args& args) {
  // Reused for every entry; ParseKey overwrites it each time.
  Rendezvous::ParsedKey parsed_key;
  for (auto& entry : *out) {
    const string& tensor_key = entry.first;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(tensor_key, &parsed_key));

    bool dead = false;
    TF_RETURN_IF_ERROR(
        rendezvous->Recv(parsed_key, args, &entry.second, &dead));
    if (!dead) {
      continue;
    }
    return errors::InvalidArgument("The tensor returned for ", tensor_key,
                                   " was not valid.");
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef std::map<string, Tensor> NamedTensors;
|
||||||
|
typedef std::function<void(const Status&)> StatusCallback;
|
||||||
|
|
||||||
|
// Uses `rendezvous` to send each tensor in `tensors_to_send`, keyed by the
// corresponding entry of `keys`.
|
||||||
|
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
gtl::ArraySlice<Tensor> tensors_to_send);
|
||||||
|
|
||||||
|
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done);
|
||||||
|
|
||||||
|
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||||
|
const Rendezvous::Args& args);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class RendezvousUtilTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
|
||||||
|
|
||||||
|
~RendezvousUtilTest() override { rendez_->Unref(); }
|
||||||
|
|
||||||
|
Rendezvous* rendez_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// string -> Tensor<string>
|
||||||
|
Tensor V(const string& content) {
|
||||||
|
Tensor tensor(DT_STRING, TensorShape({}));
|
||||||
|
tensor.scalar<string>()() = content;
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tensor<string> -> string
|
||||||
|
string V(const Tensor& tensor) {
|
||||||
|
CHECK_EQ(tensor.dtype(), DT_STRING);
|
||||||
|
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
|
||||||
|
return tensor.scalar<string>()();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the rendezvous key string used by these tests for the tensor called
// `name`. Everything except `name` is fixed: the two localhost device names
// below, the value 0, and FrameAndIter(0, 0).
string MakeStringKey(const string& name) {
  const string first_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string second_device = "/job:localhost/replica:0/task:0/device:GPU:0";
  return Rendezvous::CreateKey(first_device, 0, second_device, name,
                               FrameAndIter(0, 0));
}
|
||||||
|
|
||||||
|
// Sends two tensors first, then receives them asynchronously and checks that
// each arrives in the slot matching its key.
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
  // Fire off the sends before receiving the tensors.
  Rendezvous::Args args;
  TF_ASSERT_OK(SendTensorsToRendezvous(
      rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
      {V("hello1"), V("hello2")}));

  Notification n;
  // Captured from the callback so a failed receive is reported as such rather
  // than surfacing later as a bad tensor.
  Status recv_status;
  std::vector<Tensor> received_keys;
  RecvOutputsFromRendezvousAsync(
      rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
      &received_keys, [&n, &recv_status](const Status& status) {
        recv_status = status;
        n.Notify();
      });
  n.WaitForNotification();

  TF_ASSERT_OK(recv_status);
  // ASSERT (not EXPECT): the element accesses below are only valid if both
  // tensors are present.
  ASSERT_EQ(2, received_keys.size());
  EXPECT_EQ("hello1", V(received_keys[0]));
  EXPECT_EQ("hello2", V(received_keys[1]));
}
|
||||||
|
|
||||||
|
// Starts the asynchronous receives first, then sends the tensors and checks
// that each arrives in the slot matching its key.
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
  // Fire off recvs, wait for a notification in the callback.
  Rendezvous::Args args;

  Notification n;
  // Captured from the callback so a failed receive is reported as such rather
  // than surfacing later as a bad tensor.
  Status recv_status;
  std::vector<Tensor> received_keys;
  RecvOutputsFromRendezvousAsync(
      rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
      &received_keys, [&n, &recv_status](const Status& status) {
        recv_status = status;
        n.Notify();
      });

  TF_ASSERT_OK(SendTensorsToRendezvous(
      rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
      {V("hello1"), V("hello2")}));

  n.WaitForNotification();

  TF_ASSERT_OK(recv_status);
  // ASSERT (not EXPECT): the element accesses below are only valid if both
  // tensors are present.
  ASSERT_EQ(2, received_keys.size());
  EXPECT_EQ("hello1", V(received_keys[0]));
  EXPECT_EQ("hello2", V(received_keys[1]));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
|
|
||||||
const NamedTensors& in) {
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
for (const auto& p : in) {
|
|
||||||
const string& key = p.first;
|
|
||||||
const Tensor& val = p.second;
|
|
||||||
|
|
||||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
|
||||||
if (s.ok()) {
|
|
||||||
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
|
|
||||||
}
|
|
||||||
if (!s.ok()) {
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
|
|
||||||
NamedTensors* out) {
|
|
||||||
// Receives values requested by the caller.
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
for (auto& p : *out) {
|
|
||||||
const string& key = p.first;
|
|
||||||
Tensor* val = &p.second;
|
|
||||||
bool is_dead = false;
|
|
||||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
|
||||||
if (s.ok()) {
|
|
||||||
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
|
|
||||||
}
|
|
||||||
if (is_dead) {
|
|
||||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
|
||||||
" was not valid.");
|
|
||||||
}
|
|
||||||
if (!s.ok()) return s;
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
|
||||||
NamedTensors* out,
|
|
||||||
const StatusCallback& done) {
|
|
||||||
if (out->empty()) {
|
|
||||||
done(Status::OK());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// We compute the args before calling RecvAsync because we need to ensure that
|
|
||||||
// out isn't being iterated over after done is called, since done deletes out.
|
|
||||||
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
|
|
||||||
for (auto& p : *out) {
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
Status s = Rendezvous::ParseKey(p.first, &parsed);
|
|
||||||
if (!s.ok()) {
|
|
||||||
done(s);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
args.push_back(std::make_tuple(p.first, &p.second, parsed));
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
mutex mu;
|
|
||||||
int done_counter;
|
|
||||||
Status shared_status = Status::OK();
|
|
||||||
} CallState;
|
|
||||||
CallState* call_state = new CallState;
|
|
||||||
call_state->done_counter = out->size();
|
|
||||||
for (auto& p : args) {
|
|
||||||
const string& key = std::get<0>(p);
|
|
||||||
Tensor* val = std::get<1>(p);
|
|
||||||
Rendezvous::ParsedKey parsed = std::get<2>(p);
|
|
||||||
rendezvous->RecvAsync(
|
|
||||||
parsed, Rendezvous::Args(),
|
|
||||||
[val, done, key, call_state](const Status& s,
|
|
||||||
const Rendezvous::Args& send_args,
|
|
||||||
const Rendezvous::Args& recv_args,
|
|
||||||
const Tensor& v, const bool is_dead) {
|
|
||||||
Status status = s;
|
|
||||||
if (status.ok()) {
|
|
||||||
*val = v;
|
|
||||||
if (is_dead) {
|
|
||||||
status = errors::InvalidArgument("The tensor returned for ", key,
|
|
||||||
" was not valid.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
call_state->mu.lock();
|
|
||||||
call_state->shared_status.Update(status);
|
|
||||||
call_state->done_counter--;
|
|
||||||
// If we are the last async call to return, call the done callback.
|
|
||||||
if (call_state->done_counter == 0) {
|
|
||||||
const Status& final_status = call_state->shared_status;
|
|
||||||
call_state->mu.unlock();
|
|
||||||
done(final_status);
|
|
||||||
delete call_state;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
call_state->mu.unlock();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = SendInputsToRendezvous(rendezvous, in);
|
std::vector<string> keys;
|
||||||
|
std::vector<Tensor> tensors_to_send;
|
||||||
|
keys.reserve(in.size());
|
||||||
|
tensors_to_send.reserve(in.size());
|
||||||
|
for (const auto& p : in) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
tensors_to_send.push_back(p.second);
|
||||||
|
}
|
||||||
|
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||||
|
tensors_to_send);
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = RecvOutputsFromRendezvous(rendezvous, out);
|
Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
|||||||
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
RecvOutputsFromRendezvousAsync(rendezvous, out,
|
std::vector<string> keys;
|
||||||
[done, rendezvous](const Status s) {
|
std::vector<Tensor>* received_keys = new std::vector<Tensor>;
|
||||||
rendezvous->Unref();
|
keys.reserve(out->size());
|
||||||
done(s);
|
received_keys->reserve(out->size());
|
||||||
});
|
for (const auto& p : *out) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
received_keys->push_back(p.second);
|
||||||
|
}
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendezvous, Rendezvous::Args(), keys, received_keys,
|
||||||
|
[done, rendezvous, received_keys, out, keys](const Status s) {
|
||||||
|
rendezvous->Unref();
|
||||||
|
for (int i = 0; i < keys.size(); ++i) {
|
||||||
|
(*out)[keys[i]] = (*received_keys)[i];
|
||||||
|
}
|
||||||
|
delete received_keys;
|
||||||
|
done(s);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||||
@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
|||||||
|
|
||||||
// Sends values specified by the caller.
|
// Sends values specified by the caller.
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = SendInputsToRendezvous(rendezvous, in);
|
std::vector<string> keys;
|
||||||
|
std::vector<Tensor> tensors_to_send;
|
||||||
|
keys.reserve(in.size());
|
||||||
|
tensors_to_send.reserve(in.size());
|
||||||
|
for (auto& p : in) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
tensors_to_send.push_back(p.second);
|
||||||
|
}
|
||||||
|
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||||
|
tensors_to_send);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
|
@ -169,11 +169,6 @@ class GraphMgr {
|
|||||||
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
||||||
CostGraphDef* cost_graph);
|
CostGraphDef* cost_graph);
|
||||||
|
|
||||||
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
|
|
||||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
|
|
||||||
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
|
|
||||||
const StatusCallback& done);
|
|
||||||
|
|
||||||
Status InitItem(const string& session, const GraphDef& gdef,
|
Status InitItem(const string& session, const GraphDef& gdef,
|
||||||
const GraphOptions& graph_options,
|
const GraphOptions& graph_options,
|
||||||
const DebugOptions& debug_options, Item* item);
|
const DebugOptions& debug_options, Item* item);
|
||||||
|
@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
|
|||||||
StepStatsCollector* stats_collector = nullptr;
|
StepStatsCollector* stats_collector = nullptr;
|
||||||
|
|
||||||
std::function<void(std::function<void()>)>* runner = nullptr;
|
std::function<void(std::function<void()>)>* runner = nullptr;
|
||||||
|
|
||||||
|
// Parameters for remote function execution.
|
||||||
|
bool remote_execution = false;
|
||||||
|
string source_device = ""; // Fully specified device name.
|
||||||
};
|
};
|
||||||
typedef std::function<void(const Status&)> DoneCallback;
|
typedef std::function<void(const Status&)> DoneCallback;
|
||||||
virtual void Run(const Options& opts, Handle handle,
|
virtual void Run(const Options& opts, Handle handle,
|
||||||
|
@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
|
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
|
||||||
AttrValueMap attr_values = func_->attr();
|
AttrValueMap attr_values = func_->attr();
|
||||||
AttrValue v;
|
AttrValue v;
|
||||||
v.set_s(target->scalar<string>()());
|
const string& target_device = target->scalar<string>()();
|
||||||
|
v.set_s(target_device);
|
||||||
AddAttr("_target", v, &attr_values);
|
AddAttr("_target", v, &attr_values);
|
||||||
|
|
||||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||||
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
FunctionLibraryRuntime::Options opts;
|
FunctionLibraryRuntime::Options opts;
|
||||||
opts.step_id = ctx->step_id();
|
opts.step_id = ctx->step_id();
|
||||||
opts.runner = ctx->runner();
|
opts.runner = ctx->runner();
|
||||||
|
opts.source_device = lib->device()->name();
|
||||||
|
if (opts.source_device != target_device) {
|
||||||
|
opts.remote_execution = true;
|
||||||
|
}
|
||||||
|
opts.rendezvous = ctx->rendezvous();
|
||||||
std::vector<Tensor> args;
|
std::vector<Tensor> args;
|
||||||
args.reserve(arguments.size());
|
args.reserve(arguments.size());
|
||||||
for (const Tensor& argument : arguments) {
|
for (const Tensor& argument : arguments) {
|
||||||
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp);
|
REGISTER_KERNEL_BUILDER(
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp);
|
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
|
||||||
#if TENSORFLOW_USE_SYCL
|
#if TENSORFLOW_USE_SYCL
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp);
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
|
||||||
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
mul = sess.run(remote_op)
|
mul = sess.run(remote_op)
|
||||||
self.assertEqual(mul, [6])
|
self.assertEqual(mul, [6])
|
||||||
|
|
||||||
|
def testRemoteFunctionCPUGPU(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32, dtypes.float32)
|
||||||
|
def _remote_fn(a, b):
|
||||||
|
return math_ops.multiply(a, b)
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||||
|
a = variables.Variable(2, dtype=dtypes.float32)
|
||||||
|
b = variables.Variable(3, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||||
|
remote_op = functional_ops.remote_call(
|
||||||
|
args=[a, b],
|
||||||
|
Tout=[dtypes.float32],
|
||||||
|
f=_remote_fn,
|
||||||
|
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
mul = sess.run(remote_op)
|
||||||
|
self.assertEqual(mul, 9.0)
|
||||||
|
|
||||||
|
def testRemoteFunctionGPUCPU(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32, dtypes.float32)
|
||||||
|
def _remote_fn(a, b):
|
||||||
|
return math_ops.multiply(a, b)
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||||
|
a = variables.Variable(2, dtype=dtypes.float32)
|
||||||
|
b = variables.Variable(3, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||||
|
remote_op = functional_ops.remote_call(
|
||||||
|
args=[a, b],
|
||||||
|
Tout=[dtypes.float32],
|
||||||
|
f=_remote_fn,
|
||||||
|
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
mul = sess.run(remote_op)
|
||||||
|
self.assertEqual(mul, 9.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user