Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.

PiperOrigin-RevId: 168038285
This commit is contained in:
Rohan Jain 2017-09-08 13:30:17 -07:00 committed by TensorFlower Gardener
parent 82cc6529f4
commit 450c3b5626
16 changed files with 747 additions and 174 deletions

View File

@ -24,6 +24,8 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],

View File

@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker_config.device_count["CPU"] = 3
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
# Fails when target is cpu:2 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
})
elem = sess.run(
remote_op,
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
def _encode_raw(byte_array):
return "".join([chr(item) for item in byte_array])
@function.Defun(dtypes.uint8)
def _remote_fn(h):
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
remote_iterator = dataset_ops.Iterator.from_string_handle(
handle, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
iterator_3_handle_uint8 = parsing_ops.decode_raw(
bytes=iterator_3_handle, out_type=dtypes.uint8)
remote_op = functional_ops.remote_call(
args=[iterator_3_handle_uint8],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)
with self.test_session() as sess:
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [1])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
if __name__ == "__main__":
test.main()

View File

@ -1785,6 +1785,7 @@ tf_cuda_library(
"common_runtime/process_util.cc",
"common_runtime/renamed_device.cc",
"common_runtime/rendezvous_mgr.cc",
"common_runtime/rendezvous_util.cc",
"common_runtime/resource_variable_read_optimizer.cc",
"common_runtime/session.cc",
"common_runtime/session_factory.cc",
@ -1828,6 +1829,7 @@ tf_cuda_library(
"common_runtime/profile_handler.h",
"common_runtime/renamed_device.h",
"common_runtime/rendezvous_mgr.h",
"common_runtime/rendezvous_util.h",
"common_runtime/session_factory.h",
"common_runtime/graph_execution_state.h",
"common_runtime/placer.h",
@ -2669,29 +2671,29 @@ tf_cc_test(
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
":framework",
":framework_internal",
":lib",
":lib_internal",
":ops",
":protos_all_cc",
":test",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "common_runtime_rendezvous_util_test",
size = "small",
srcs = ["common_runtime/rendezvous_util_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core_cpu_internal",
":lib",
":test",
":test_main",
],
)

View File

@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionBody** g_body);
bool IsLocalTarget(const AttrSlice& attrs);
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
Executor::Args* exec_args, Item* item, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return Status::OK();
}
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
Executor::Args* exec_args,
Item* item, DoneCallback done) {
FunctionCallFrame* frame = exec_args->call_frame;
string target_device = parent_->GetDeviceName(handle);
string source_device = opts.source_device;
Rendezvous* rendezvous = opts.rendezvous;
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
Rendezvous::Args rendez_args;
Status s =
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", args.size(), rendez_args,
rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device, rendezvous,
rendez_args, rets, done, exec_args](const Status& status) {
Status s = status;
s = frame->SetArgs(*remote_args);
if (!s.ok()) {
delete frame;
delete remote_args;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync(
*exec_args,
[item, frame, rets, done, source_device, target_device, rendezvous,
rendez_args, remote_args, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
delete frame;
if (!s.ok()) {
delete remote_args;
delete exec_args;
done(s);
return;
}
s = ProcessFunctionLibraryRuntime::SendTensors(
target_device, source_device, "ret_", *rets, rendez_args,
rendezvous);
delete remote_args;
delete exec_args;
done(s);
});
});
}
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
return done(errors::Cancelled(""));
done(errors::Cancelled(""));
return;
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->Run(opts, handle, args, rets, done);
parent_->Run(opts, handle, args, rets, done);
return;
}
const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
Status s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
return done(s);
}
Item* item = nullptr;
s = GetOrCreateItem(handle, &item);
Status s = GetOrCreateItem(handle, &item);
if (!s.ok()) {
delete frame;
return done(s);
done(s);
return;
}
DCHECK(opts.runner != nullptr);
Executor::Args exec_args;
Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller.
exec_args.step_id = opts.step_id;
exec_args.rendezvous = opts.rendezvous;
exec_args.stats_collector = opts.stats_collector;
exec_args.call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager;
exec_args.step_container = opts.step_container;
exec_args.runner = *opts.runner;
exec_args->step_id = opts.step_id;
exec_args->rendezvous = opts.rendezvous;
exec_args->stats_collector = opts.stats_collector;
exec_args->call_frame = frame;
exec_args->cancellation_manager = opts.cancellation_manager;
exec_args->step_container = opts.step_container;
exec_args->runner = *opts.runner;
if (opts.remote_execution) {
RunRemote(opts, handle, args, rets, exec_args, item, done);
return;
}
s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync(
// Executor args
exec_args,
*exec_args,
// Done callback.
[item, frame, rets, done](const Status& status) {
[item, frame, rets, done, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
s = frame->GetRetvals(rets);
s = frame->ConsumeRetvals(rets);
}
delete frame;
delete exec_args;
done(s);
});
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
};
Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner;
std::vector<Tensor> out;
Status status;
@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
if (!status.ok()) {
return status;
}
return Run(flr, handle, args, std::move(rets));
FunctionLibraryRuntime::Options opts;
return Run(flr, handle, opts, args, std::move(rets));
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
opts.rendezvous->Unref();
}
namespace {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
}
}
/* static */
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
const AttrSlice& attrs) {
const AttrValue* value;
@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
return value->s();
}
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args, Rendezvous* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
TF_RETURN_IF_ERROR(
SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send));
return Status::OK();
}
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
RecvOutputsFromRendezvousAsync(
rendezvous, args, keys, received_tensors,
[done](const Status& status) { done(status); });
}
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
const string& device_name, DeviceContext** device_context) {
*device_context = nullptr;
FunctionLibraryRuntime* flr = GetFLR(device_name);
if (flr == nullptr) {
return errors::InvalidArgument("Device name: ", device_name, " not found.");
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
if (device_type == "CPU") return Status::OK();
if (device_type == "GPU") {
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
if (dev_info) {
*device_context = dev_info->default_context;
return Status::OK();
}
}
return errors::Internal("Device type: ", device_type,
" is currently unsupported for remote ",
"function executions");
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) {
if (flr_map_.find(device_name) == flr_map_.end()) {
@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
if (p.first != device_name) {
@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
return p.second;
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
return p.first;
}
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle) {
@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
if (!opts.remote_execution) {
done(errors::InvalidArgument(
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
"is a remote execution."));
return;
}
FunctionLibraryRuntime* flr = nullptr;
string target_device;
{
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
target_device = p.first;
flr = GetFLR(p.first);
}
if (flr != nullptr) {
return flr->Run(opts, handle, args, rets, std::move(done));
auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
Rendezvous::Args rendez_args;
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
if (!s.ok()) {
done(s);
return;
}
// Send the args over to the target device.
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
rendezvous);
if (!s.ok()) {
done(s);
return;
}
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
[source_device, target_device, rendezvous, remote_rets, rets, done,
rendez_args](const Status& status) {
if (!status.ok()) {
delete remote_rets;
done(status);
return;
}
int64 num_returns = remote_rets->size();
delete remote_rets;
// Now receive the return values from the target.
ReceiveTensorsAsync(target_device, source_device, "ret_",
num_returns, rendez_args, rendezvous, rets,
done);
});
} else {
done(errors::Internal("Could not find device"));
return;
}
}

View File

@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
// attribute, returns "". Canonicalizes the device name.
static string ObtainFunctionTarget(const AttrSlice& attrs);
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. Method takes references on each of the `tensors_to_send`.
// Method doesn't block.
static Status SendTensors(const string& source_device,
const string& target_device,
const string& key_prefix,
gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args,
Rendezvous* rendezvous);
typedef std::function<void(const Status&)> StatusCallback;
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. Method doesn't block and calls `done` when
// `num_tensors` are fetched.
static void ReceiveTensorsAsync(const string& source_device,
const string& target_device,
const string& key_prefix, int64 num_tensors,
const Rendezvous::Args& args,
Rendezvous* rendezvous,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
FunctionLibraryRuntime* GetFLR(const string& device_name);
@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback done);
private:
// For a given device_name, returns a DeviceContext for copying
// tensors to/from the device.
Status GetDeviceContext(const string& device_name,
DeviceContext** device_context);
// Looks up the information for the given `handle` and returns the name
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
friend class FunctionLibraryRuntimeImpl;
mutable mutex mu_;
// Holds all the function invocations here.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
Status Run(const string& name, test::function::Attrs attrs,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs, const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = proc_flr_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
};
Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner;
std::vector<Tensor> out;
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
IntraProcessRendezvous* rendezvous_;
};
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
EXPECT_EQ(flr->device(), devices_[1]);
flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr);
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
Init({test::function::XTimesTwo()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
TF_CHECK_OK(
Run("XTimesTwo",
Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({})));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(
Run("XTimesTwo",
Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK(
Run("XTimesFour",
Run("XTimesFour", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
rendezvous_->Unref();
}
} // anonymous namespace

View File

@ -0,0 +1,119 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
namespace tensorflow {
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send) {
if (keys.size() != tensors_to_send.size()) {
return errors::InvalidArgument(
"keys and tensors_to_send are not the same size. keys.size() = ",
keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
}
Rendezvous::ParsedKey parsed;
for (int i = 0; i < keys.size(); ++i) {
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
TF_RETURN_IF_ERROR(
rendezvous->Send(parsed, args, tensors_to_send[i], false));
}
return Status::OK();
}
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
if (keys.empty()) {
done(Status::OK());
return;
}
received_tensors->reserve(keys.size());
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
for (int i = 0; i < keys.size(); ++i) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(keys[i], &parsed);
received_tensors->push_back(Tensor());
if (!s.ok()) {
done(s);
return;
}
arguments.push_back(
std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
}
typedef struct {
mutex mu;
int64 done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = keys.size();
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, args,
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
if (is_dead) {
return errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
namespace tensorflow {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
// Uses `rendezvous` to send tensors in `in`.
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send);
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_

View File

@ -0,0 +1,94 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class RendezvousUtilTest : public ::testing::Test {
public:
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
~RendezvousUtilTest() override { rendez_->Unref(); }
Rendezvous* rendez_;
};
// string -> Tensor<string>
Tensor V(const string& content) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<string>()() = content;
return tensor;
}
// Tensor<string> -> string
string V(const Tensor& tensor) {
CHECK_EQ(tensor.dtype(), DT_STRING);
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
return tensor.scalar<string>()();
}
string MakeStringKey(const string& name) {
return Rendezvous::CreateKey(
"/job:localhost/replica:0/task:0/device:CPU:0", 0,
"/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
}
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
// Fire off sends before receive the tensors.
Rendezvous::Args args;
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
// Fire off recvs, wait for a notification in the callback.
Rendezvous::Args args;
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
return Status::OK();
}
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
const NamedTensors& in) {
Rendezvous::ParsedKey parsed;
for (const auto& p : in) {
const string& key = p.first;
const Tensor& val = p.second;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
}
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
NamedTensors* out) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
}
if (is_dead) {
s = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
if (!s.ok()) return s;
}
return Status::OK();
}
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
NamedTensors* out,
const StatusCallback& done) {
if (out->empty()) {
done(Status::OK());
return;
}
// We compute the args before calling RecvAsync because we need to ensure that
// out isn't being iterated over after done is called, since done deletes out.
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
for (auto& p : *out) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(p.first, &parsed);
if (!s.ok()) {
done(s);
return;
}
args.push_back(std::make_tuple(p.first, &p.second, parsed));
}
typedef struct {
mutex mu;
int done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = out->size();
for (auto& p : args) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, Rendezvous::Args(),
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (const auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
rendezvous->Unref();
return s;
}
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out);
Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
rendezvous->Unref();
return s;
}
@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) {
rendezvous->Unref();
done(s);
});
std::vector<string> keys;
std::vector<Tensor>* received_keys = new std::vector<Tensor>;
keys.reserve(out->size());
received_keys->reserve(out->size());
for (const auto& p : *out) {
keys.push_back(p.first);
received_keys->push_back(p.second);
}
RecvOutputsFromRendezvousAsync(
rendezvous, Rendezvous::Args(), keys, received_keys,
[done, rendezvous, received_keys, out, keys](const Status s) {
rendezvous->Unref();
for (int i = 0; i < keys.size(); ++i) {
(*out)[keys[i]] = (*received_keys)[i];
}
delete received_keys;
done(s);
});
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
// Sends values specified by the caller.
if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in);
std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
}
if (!s.ok()) {

View File

@ -169,11 +169,6 @@ class GraphMgr {
void BuildCostModel(Item* item, StepStatsCollector* collector,
CostGraphDef* cost_graph);
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
const StatusCallback& done);
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options, Item* item);

View File

@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
StepStatsCollector* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
// Parameters for remote function execution.
bool remote_execution = false;
string source_device = ""; // Fully specified device name.
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle,

View File

@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
AttrValueMap attr_values = func_->attr();
AttrValue v;
v.set_s(target->scalar<string>()());
const string& target_device = target->scalar<string>()();
v.set_s(target_device);
AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library();
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.runner = ctx->runner();
opts.source_device = lib->device()->name();
if (opts.source_device != target_device) {
opts.remote_execution = true;
}
opts.rendezvous = ctx->rendezvous();
std::vector<Tensor> args;
args.reserve(arguments.size());
for (const Tensor& argument : arguments) {
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
};
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp);
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op)
self.assertEqual(mul, [6])
def testRemoteFunctionCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
def testRemoteFunctionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
if __name__ == "__main__":
test.main()