diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2f93c345027..6891fd4231f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -24,6 +24,8 @@ py_test( "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:training", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index b20742f7758..7ee21d4e01d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import script_ops from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -420,7 +423,7 @@ class IteratorTest(test.TestCase): def testRemoteIteratorUsingRemoteCallOpDirectSession(self): worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config.device_count["CPU"] = 3 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) @@ -448,12 +451,12 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [1]) - # Fails when target is cpu:0 where the resource is not located. + # Fails when target is cpu:2 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): sess.run( remote_op, feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" }) elem = sess.run( remote_op, @@ -474,6 +477,61 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) + def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_3_handle = iterator_3.string_handle() + + def _encode_raw(byte_array): + return "".join([chr(item) for item in byte_array]) + + @function.Defun(dtypes.uint8) + def _remote_fn(h): + handle = script_ops.py_func(_encode_raw, [h], dtypes.string) + remote_iterator = dataset_ops.Iterator.from_string_handle( + handle, dataset_3.output_types, dataset_3.output_shapes) + return remote_iterator.get_next() + + with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): + target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + iterator_3_handle_uint8 = parsing_ops.decode_raw( + bytes=iterator_3_handle, out_type=dtypes.uint8) + remote_op = functional_ops.remote_call( + args=[iterator_3_handle_uint8], + Tout=[dtypes.int32], + f=_remote_fn, + target=target_placeholder) + + with self.test_session() as sess: + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + }) + self.assertEqual(elem, [1]) + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + }) + self.assertEqual(elem, [2]) + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + }) + self.assertEqual(elem, [3]) + with self.assertRaises(errors.OutOfRangeError): + sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + }) + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4b80d2c543e..03ea115c239 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1785,6 +1785,7 @@ tf_cuda_library( "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", + "common_runtime/rendezvous_util.cc", "common_runtime/resource_variable_read_optimizer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", @@ -1828,6 +1829,7 @@ tf_cuda_library( "common_runtime/profile_handler.h", "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", + "common_runtime/rendezvous_util.h", "common_runtime/session_factory.h", "common_runtime/graph_execution_state.h", "common_runtime/placer.h", @@ -2669,29 +2671,29 @@ tf_cc_test( srcs = ["common_runtime/process_function_library_runtime_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), deps = [ - ":core", ":core_cpu", ":core_cpu_internal", - ":direct_session_internal", ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", ":test", ":test_main", ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", - "//tensorflow/cc:functional_ops", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:shape_ops", - "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime_rendezvous_util_test", + size = "small", + srcs = ["common_runtime/rendezvous_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu_internal", + ":lib", + ":test", + ":test_main", ], ) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 4b239606a84..4aeacc6d612 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { FunctionBody** g_body); bool IsLocalTarget(const AttrSlice& attrs); AttrValueMap FixAttrs(const AttrSlice& attrs); + void RunRemote(const Options& opts, Handle handle, + gtl::ArraySlice args, std::vector* rets, + Executor::Args* exec_args, Item* item, DoneCallback done); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { return Status::OK(); } +void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, + gtl::ArraySlice args, + std::vector* rets, + Executor::Args* exec_args, + Item* item, DoneCallback done) { + FunctionCallFrame* frame = exec_args->call_frame; + string target_device = parent_->GetDeviceName(handle); + string source_device = opts.source_device; + Rendezvous* rendezvous = opts.rendezvous; + // TODO(rohanj): Handle alloc_attrs in Rendezvous::Args. + Rendezvous::Args rendez_args; + Status s = + parent_->GetDeviceContext(target_device, &rendez_args.device_context); + if (!s.ok()) { + delete frame; + delete exec_args; + done(s); + return; + } + + // The ProcFLR sends the arguments to the function from the source_device to + // the target_device. So here we receive those arguments. Similarly, when the + // computation is done and stored in *rets, we send the return values back + // to the source_device (caller) so that the ProcFLR can receive them later. + std::vector* remote_args = new std::vector; + ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( + source_device, target_device, "arg_", args.size(), rendez_args, + rendezvous, remote_args, + [frame, remote_args, item, source_device, target_device, rendezvous, + rendez_args, rets, done, exec_args](const Status& status) { + Status s = status; + s = frame->SetArgs(*remote_args); + if (!s.ok()) { + delete frame; + delete remote_args; + delete exec_args; + done(s); + return; + } + item->exec->RunAsync( + *exec_args, + [item, frame, rets, done, source_device, target_device, rendezvous, + rendez_args, remote_args, exec_args](const Status& status) { + item->Unref(); + Status s = status; + if (s.ok()) { + s = frame->ConsumeRetvals(rets); + } + delete frame; + if (!s.ok()) { + delete remote_args; + delete exec_args; + done(s); + return; + } + s = ProcessFunctionLibraryRuntime::SendTensors( + target_device, source_device, "ret_", *rets, rendez_args, + rendezvous); + delete remote_args; + delete exec_args; + done(s); + }); + }); +} + void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, DoneCallback done) { if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { - return done(errors::Cancelled("")); + done(errors::Cancelled("")); + return; } if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { - return parent_->Run(opts, handle, args, rets, done); + parent_->Run(opts, handle, args, rets, done); + return; } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); - Status s = frame->SetArgs(args); - if (!s.ok()) { - delete frame; - return done(s); - } + Item* item = nullptr; - s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(handle, &item); if (!s.ok()) { delete frame; - return done(s); + done(s); + return; } DCHECK(opts.runner != nullptr); - Executor::Args exec_args; + Executor::Args* exec_args = new Executor::Args; // Inherit the step_id from the caller. - exec_args.step_id = opts.step_id; - exec_args.rendezvous = opts.rendezvous; - exec_args.stats_collector = opts.stats_collector; - exec_args.call_frame = frame; - exec_args.cancellation_manager = opts.cancellation_manager; - exec_args.step_container = opts.step_container; - exec_args.runner = *opts.runner; + exec_args->step_id = opts.step_id; + exec_args->rendezvous = opts.rendezvous; + exec_args->stats_collector = opts.stats_collector; + exec_args->call_frame = frame; + exec_args->cancellation_manager = opts.cancellation_manager; + exec_args->step_container = opts.step_container; + exec_args->runner = *opts.runner; + + if (opts.remote_execution) { + RunRemote(opts, handle, args, rets, exec_args, item, done); + return; + } + + s = frame->SetArgs(args); + if (!s.ok()) { + delete frame; + delete exec_args; + done(s); + return; + } item->exec->RunAsync( // Executor args - exec_args, + *exec_args, // Done callback. - [item, frame, rets, done](const Status& status) { + [item, frame, rets, done, exec_args](const Status& status) { item->Unref(); Status s = status; if (s.ok()) { - s = frame->GetRetvals(rets); + s = frame->ConsumeRetvals(rets); } delete frame; + delete exec_args; done(s); }); } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a9f06c4df03..7eac1674e71 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, + FunctionLibraryRuntime::Options opts, const std::vector& args, std::vector rets) { std::atomic call_count(0); std::function)> runner = @@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { }; Notification done; - FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector out; Status status; @@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { if (!status.ok()) { return status; } - return Run(flr, handle, args, std::move(rets)); + FunctionLibraryRuntime::Options opts; + return Run(flr, handle, opts, args, std::move(rets)); } std::unique_ptr GetFuncBody(FunctionLibraryRuntime* flr, @@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); Tensor y; + FunctionLibraryRuntime::Options opts; + opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. - TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); + TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); - TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); + opts.remote_execution = true; + opts.source_device = "/job:localhost/replica:0/task:0/cpu:2"; + TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); + opts.rendezvous->Unref(); } namespace { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 0caec036252..c39bab2348e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace tensorflow { @@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } +/* static */ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( const AttrSlice& attrs) { const AttrValue* value; @@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( return value->s(); } +/* static */ +Status ProcessFunctionLibraryRuntime::SendTensors( + const string& source_device, const string& target_device, + const string& key_prefix, gtl::ArraySlice tensors_to_send, + const Rendezvous::Args& args, Rendezvous* rendezvous) { + std::vector keys; + for (int i = 0; i < tensors_to_send.size(); ++i) { + string name = strings::StrCat(key_prefix, i); + string key = Rendezvous::CreateKey(source_device, i, target_device, name, + FrameAndIter(0, 0)); + keys.push_back(key); + } + TF_RETURN_IF_ERROR( + SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send)); + return Status::OK(); +} + +/* static */ +void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( + const string& source_device, const string& target_device, + const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args, + Rendezvous* rendezvous, std::vector* received_tensors, + const StatusCallback& done) { + std::vector keys; + for (int64 i = 0; i < num_tensors; ++i) { + string name = strings::StrCat(key_prefix, i); + string key = Rendezvous::CreateKey(source_device, i, target_device, name, + FrameAndIter(0, 0)); + keys.push_back(key); + } + RecvOutputsFromRendezvousAsync( + rendezvous, args, keys, received_tensors, + [done](const Status& status) { done(status); }); +} + +Status ProcessFunctionLibraryRuntime::GetDeviceContext( + const string& device_name, DeviceContext** device_context) { + *device_context = nullptr; + FunctionLibraryRuntime* flr = GetFLR(device_name); + if (flr == nullptr) { + return errors::InvalidArgument("Device name: ", device_name, " not found."); + } + Device* device = flr->device(); + string device_type = device->parsed_name().type; + if (device_type == "CPU") return Status::OK(); + if (device_type == "GPU") { + auto* dev_info = flr->device()->tensorflow_gpu_device_info(); + if (dev_info) { + *device_context = dev_info->default_context; + return Status::OK(); + } + } + return errors::Internal("Device type: ", device_type, + " is currently unsupported for remote ", + "function executions"); +} + FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( const string& device_name) { if (flr_map_.find(device_name) == flr_map_.end()) { @@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle ProcessFunctionLibraryRuntime::GetHandleOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); std::pair p = function_data_[handle]; if (p.first != device_name) { @@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( return p.second; } +string ProcessFunctionLibraryRuntime::GetDeviceName( + FunctionLibraryRuntime::Handle handle) { + mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); + std::pair p = + function_data_[handle]; + return p.first; +} + Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { @@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { + if (!opts.remote_execution) { + done(errors::InvalidArgument( + "ProcessFunctionLibraryRuntime::Run should only be called when there ", + "is a remote execution.")); + return; + } + FunctionLibraryRuntime* flr = nullptr; + string target_device; { mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); std::pair p = function_data_[handle]; + target_device = p.first; flr = GetFLR(p.first); } if (flr != nullptr) { - return flr->Run(opts, handle, args, rets, std::move(done)); + auto rendezvous = opts.rendezvous; + string source_device = opts.source_device; + Rendezvous::Args rendez_args; + Status s = GetDeviceContext(source_device, &rendez_args.device_context); + if (!s.ok()) { + done(s); + return; + } + // Send the args over to the target device. + s = SendTensors(source_device, target_device, "arg_", args, rendez_args, + rendezvous); + if (!s.ok()) { + done(s); + return; + } + std::vector* remote_rets = new std::vector; + flr->Run(opts, handle, args, remote_rets, + [source_device, target_device, rendezvous, remote_rets, rets, done, + rendez_args](const Status& status) { + if (!status.ok()) { + delete remote_rets; + done(status); + return; + } + int64 num_returns = remote_rets->size(); + delete remote_rets; + // Now receive the return values from the target. + ReceiveTensorsAsync(target_device, source_device, "ret_", + num_returns, rendez_args, rendezvous, rets, + done); + }); + } else { + done(errors::Internal("Could not find device")); + return; } } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 2259997005e..2e97bae4b4f 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime { // attribute, returns "". Canonicalizes the device name. static string ObtainFunctionTarget(const AttrSlice& attrs); + // Sends `tensors_to_send` from `source_device` to `target_device` using + // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the + // Rendezvous. Method takes references on each of the `tensors_to_send`. + // Method doesn't block. + static Status SendTensors(const string& source_device, + const string& target_device, + const string& key_prefix, + gtl::ArraySlice tensors_to_send, + const Rendezvous::Args& args, + Rendezvous* rendezvous); + + typedef std::function StatusCallback; + + // Receives `received_tensors` from `target_device` (originally sent from + // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the + // keys to be retrieved. Method doesn't block and calls `done` when + // `num_tensors` are fetched. + static void ReceiveTensorsAsync(const string& source_device, + const string& target_device, + const string& key_prefix, int64 num_tensors, + const Rendezvous::Args& args, + Rendezvous* rendezvous, + std::vector* received_tensors, + const StatusCallback& done); + static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. FunctionLibraryRuntime* GetFLR(const string& device_name); @@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback done); private: + // For a given device_name, returns a DeviceContext for copying + // tensors to/from the device. + Status GetDeviceContext(const string& device_name, + DeviceContext** device_context); + + // Looks up the information for the given `handle` and returns the name + // of the device where the function is registered. + string GetDeviceName(FunctionLibraryRuntime::Handle handle); + + friend class FunctionLibraryRuntimeImpl; + mutable mutex mu_; // Holds all the function invocations here. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 1536aedde58..fdbab46f547 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" @@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { proc_flr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts)); + rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } - Status Run(const string& name, test::function::Attrs attrs, - const std::vector& args, std::vector rets) { + Status Run(const string& name, FunctionLibraryRuntime::Options opts, + test::function::Attrs attrs, const std::vector& args, + std::vector rets) { FunctionLibraryRuntime::Handle handle; Status status = proc_flr_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }; Notification done; - FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector out; proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { @@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr device_mgr_; std::unique_ptr lib_def_; std::unique_ptr proc_flr_; + IntraProcessRendezvous* rendezvous_; }; TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { @@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { EXPECT_EQ(flr->device(), devices_[1]); flr = proc_flr_->GetFLR("abc"); EXPECT_EQ(flr, nullptr); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { @@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { Init({test::function::XTimesTwo()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; TF_CHECK_OK( - Run("XTimesTwo", + Run("XTimesTwo", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, TensorShape({}))); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor({1, 2, 3, 4}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; TF_CHECK_OK( - Run("XTimesTwo", + Run("XTimesTwo", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); TF_CHECK_OK( - Run("XTimesFour", + Run("XTimesFour", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); + rendezvous_->Unref(); } } // anonymous namespace diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc new file mode 100644 index 00000000000..a0d409e7735 --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/rendezvous_util.h" + +namespace tensorflow { + +Status SendTensorsToRendezvous(Rendezvous* rendezvous, + const Rendezvous::Args& args, + const std::vector& keys, + gtl::ArraySlice tensors_to_send) { + if (keys.size() != tensors_to_send.size()) { + return errors::InvalidArgument( + "keys and tensors_to_send are not the same size. keys.size() = ", + keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size()); + } + Rendezvous::ParsedKey parsed; + for (int i = 0; i < keys.size(); ++i) { + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed)); + TF_RETURN_IF_ERROR( + rendezvous->Send(parsed, args, tensors_to_send[i], false)); + } + return Status::OK(); +} + +void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, + const Rendezvous::Args& args, + const std::vector& keys, + std::vector* received_tensors, + const StatusCallback& done) { + if (keys.empty()) { + done(Status::OK()); + return; + } + received_tensors->reserve(keys.size()); + std::vector> arguments; + for (int i = 0; i < keys.size(); ++i) { + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(keys[i], &parsed); + received_tensors->push_back(Tensor()); + if (!s.ok()) { + done(s); + return; + } + arguments.push_back( + std::make_tuple(keys[i], &((*received_tensors)[i]), parsed)); + } + + typedef struct { + mutex mu; + int64 done_counter; + Status shared_status = Status::OK(); + } CallState; + CallState* call_state = new CallState; + call_state->done_counter = keys.size(); + for (auto& p : arguments) { + const string& key = std::get<0>(p); + Tensor* val = std::get<1>(p); + Rendezvous::ParsedKey parsed = std::get<2>(p); + rendezvous->RecvAsync( + parsed, args, + [val, done, key, call_state](const Status& s, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& v, const bool is_dead) { + Status status = s; + if (status.ok()) { + *val = v; + if (is_dead) { + status = errors::InvalidArgument("The tensor returned for ", key, + " was not valid."); + } + } + call_state->mu.lock(); + call_state->shared_status.Update(status); + call_state->done_counter--; + // If we are the last async call to return, call the done callback. + if (call_state->done_counter == 0) { + const Status& final_status = call_state->shared_status; + call_state->mu.unlock(); + done(final_status); + delete call_state; + return; + } + call_state->mu.unlock(); + }); + } +} + +Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, + const Rendezvous::Args& args) { + // Receives values requested by the caller. + Rendezvous::ParsedKey parsed; + for (auto& p : *out) { + const string& key = p.first; + Tensor* val = &p.second; + bool is_dead = false; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); + TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead)); + if (is_dead) { + return errors::InvalidArgument("The tensor returned for ", key, + " was not valid."); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h new file mode 100644 index 00000000000..a54f8c3f948 --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ + +#include + +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +typedef std::map NamedTensors; +typedef std::function StatusCallback; + +// Uses `rendezvous` to send tensors in `in`. +Status SendTensorsToRendezvous(Rendezvous* rendezvous, + const Rendezvous::Args& args, + const std::vector& keys, + gtl::ArraySlice tensors_to_send); + +void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, + const Rendezvous::Args& args, + const std::vector& keys, + std::vector* received_tensors, + const StatusCallback& done); + +Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, + const Rendezvous::Args& args); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc new file mode 100644 index 00000000000..8ee9f4d5226 --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/rendezvous_util.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class RendezvousUtilTest : public ::testing::Test { + public: + RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); } + + ~RendezvousUtilTest() override { rendez_->Unref(); } + + Rendezvous* rendez_; +}; + +// string -> Tensor +Tensor V(const string& content) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar()() = content; + return tensor; +} + +// Tensor -> string +string V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_STRING); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar()(); +} + +string MakeStringKey(const string& name) { + return Rendezvous::CreateKey( + "/job:localhost/replica:0/task:0/device:CPU:0", 0, + "/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0)); +} + +TEST_F(RendezvousUtilTest, SendBeforeRecv) { + // Fire off sends before receive the tensors. + Rendezvous::Args args; + TF_ASSERT_OK(SendTensorsToRendezvous( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + {V("hello1"), V("hello2")})); + + Notification n; + std::vector received_keys; + RecvOutputsFromRendezvousAsync( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + &received_keys, [&n](const Status& status) { n.Notify(); }); + n.WaitForNotification(); + + EXPECT_EQ(2, received_keys.size()); + EXPECT_EQ("hello1", V(received_keys[0])); + EXPECT_EQ("hello2", V(received_keys[1])); +} + +TEST_F(RendezvousUtilTest, RecvBeforeSend) { + // Fire off recvs, wait for a notification in the callback. + Rendezvous::Args args; + + Notification n; + std::vector received_keys; + RecvOutputsFromRendezvousAsync( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + &received_keys, [&n](const Status& status) { n.Notify(); }); + + TF_ASSERT_OK(SendTensorsToRendezvous( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + {V("hello1"), V("hello2")})); + + n.WaitForNotification(); + + EXPECT_EQ(2, received_keys.size()); + EXPECT_EQ("hello1", V(received_keys[0])); + EXPECT_EQ("hello2", V(received_keys[1])); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 1169b86c9db..411b6d861b7 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/rendezvous_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/framework/cancellation.h" @@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() { return Status::OK(); } -Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous, - const NamedTensors& in) { - Rendezvous::ParsedKey parsed; - for (const auto& p : in) { - const string& key = p.first; - const Tensor& val = p.second; - - Status s = Rendezvous::ParseKey(key, &parsed); - if (s.ok()) { - s = rendezvous->Send(parsed, Rendezvous::Args(), val, false); - } - if (!s.ok()) { - return s; - } - } - return Status::OK(); -} - -Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous, - NamedTensors* out) { - // Receives values requested by the caller. - Rendezvous::ParsedKey parsed; - for (auto& p : *out) { - const string& key = p.first; - Tensor* val = &p.second; - bool is_dead = false; - Status s = Rendezvous::ParseKey(key, &parsed); - if (s.ok()) { - s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead); - } - if (is_dead) { - s = errors::InvalidArgument("The tensor returned for ", key, - " was not valid."); - } - if (!s.ok()) return s; - } - return Status::OK(); -} - -void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, - NamedTensors* out, - const StatusCallback& done) { - if (out->empty()) { - done(Status::OK()); - return; - } - // We compute the args before calling RecvAsync because we need to ensure that - // out isn't being iterated over after done is called, since done deletes out. - std::vector> args; - for (auto& p : *out) { - Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(p.first, &parsed); - if (!s.ok()) { - done(s); - return; - } - args.push_back(std::make_tuple(p.first, &p.second, parsed)); - } - - typedef struct { - mutex mu; - int done_counter; - Status shared_status = Status::OK(); - } CallState; - CallState* call_state = new CallState; - call_state->done_counter = out->size(); - for (auto& p : args) { - const string& key = std::get<0>(p); - Tensor* val = std::get<1>(p); - Rendezvous::ParsedKey parsed = std::get<2>(p); - rendezvous->RecvAsync( - parsed, Rendezvous::Args(), - [val, done, key, call_state](const Status& s, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& v, const bool is_dead) { - Status status = s; - if (status.ok()) { - *val = v; - if (is_dead) { - status = errors::InvalidArgument("The tensor returned for ", key, - " was not valid."); - } - } - call_state->mu.lock(); - call_state->shared_status.Update(status); - call_state->done_counter--; - // If we are the last async call to return, call the done callback. - if (call_state->done_counter == 0) { - const Status& final_status = call_state->shared_status; - call_state->mu.unlock(); - done(final_status); - delete call_state; - return; - } - call_state->mu.unlock(); - }); - } -} - Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - Status s = SendInputsToRendezvous(rendezvous, in); + std::vector keys; + std::vector tensors_to_send; + keys.reserve(in.size()); + tensors_to_send.reserve(in.size()); + for (const auto& p : in) { + keys.push_back(p.first); + tensors_to_send.push_back(p.second); + } + Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys, + tensors_to_send); rendezvous->Unref(); return s; } Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - Status s = RecvOutputsFromRendezvous(rendezvous, out); + Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args()); rendezvous->Unref(); return s; } @@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, StatusCallback done) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - RecvOutputsFromRendezvousAsync(rendezvous, out, - [done, rendezvous](const Status s) { - rendezvous->Unref(); - done(s); - }); + std::vector keys; + std::vector* received_keys = new std::vector; + keys.reserve(out->size()); + received_keys->reserve(out->size()); + for (const auto& p : *out) { + keys.push_back(p.first); + received_keys->push_back(p.second); + } + RecvOutputsFromRendezvousAsync( + rendezvous, Rendezvous::Args(), keys, received_keys, + [done, rendezvous, received_keys, out, keys](const Status s) { + rendezvous->Unref(); + for (int i = 0; i < keys.size(); ++i) { + (*out)[keys[i]] = (*received_keys)[i]; + } + delete received_keys; + done(s); + }); } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, @@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, // Sends values specified by the caller. if (s.ok()) { - s = SendInputsToRendezvous(rendezvous, in); + std::vector keys; + std::vector tensors_to_send; + keys.reserve(in.size()); + tensors_to_send.reserve(in.size()); + for (auto& p : in) { + keys.push_back(p.first); + tensors_to_send.push_back(p.second); + } + s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys, + tensors_to_send); } if (!s.ok()) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index d719dd4ec6b..c6f55e4ef9c 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -169,11 +169,6 @@ class GraphMgr { void BuildCostModel(Item* item, StepStatsCollector* collector, CostGraphDef* cost_graph); - Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in); - Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out); - void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out, - const StatusCallback& done); - Status InitItem(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, Item* item); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 317707644b3..e3842ea58d3 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -426,6 +426,10 @@ class FunctionLibraryRuntime { StepStatsCollector* stats_collector = nullptr; std::function)>* runner = nullptr; + + // Parameters for remote function execution. + bool remote_execution = false; + string source_device = ""; // Fully specified device name. }; typedef std::function DoneCallback; virtual void Run(const Options& opts, Handle handle, diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index a1dfd4c3d31..629e29958f6 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); AttrValueMap attr_values = func_->attr(); AttrValue v; - v.set_s(target->scalar()()); + const string& target_device = target->scalar()(); + v.set_s(target_device); AddAttr("_target", v, &attr_values); FunctionLibraryRuntime* lib = ctx->function_library(); @@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); opts.runner = ctx->runner(); + opts.source_device = lib->device()->name(); + if (opts.source_device != target_device) { + opts.remote_execution = true; + } + opts.rendezvous = ctx->rendezvous(); std::vector args; args.reserve(arguments.size()); for (const Tensor& argument : arguments) { @@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); }; -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp); -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp); #if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index a7bedc7199c..9ee7c0c5611 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.float32, dtypes.float32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + a = variables.Variable(2, dtype=dtypes.float32) + b = variables.Variable(3, dtype=dtypes.float32) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.float32], + f=_remote_fn, + target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, 9.0) + + def testRemoteFunctionGPUCPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.float32, dtypes.float32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): + a = variables.Variable(2, dtype=dtypes.float32) + b = variables.Variable(3, dtype=dtypes.float32) + + with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.float32], + f=_remote_fn, + target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, 9.0) + if __name__ == "__main__": test.main()