Add ability for functions to share rendezvous

The private `_shared_rendezvous` property allows the function to use the
rendezvous of the parent. This is only needed to support code where raw
send/recv operations are inserted, and when functions are run in graph mode,
where they may not be inlined.

PiperOrigin-RevId: 315319264
Change-Id: Ieb6b3924c51ccfd201b4693f3a499f883c7c0b71
parent: 3cfba9571b, commit: b0b763203e
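Before the diff, a minimal usage sketch distilled from the new testShareRendezvous test below. The sketch itself is not part of the commit, and `_shared_rendezvous` is a private, unsupported knob at this revision:

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_sendrecv_ops

cpu = '/device:CPU:0'

@def_function.function
def send():
  x = constant_op.constant(1)
  # Raw Send: publishes `x` into the rendezvous under the key 'x'.
  gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
  return x

# Opt in: reuse the caller's rendezvous instead of a per-call one.
send._shared_rendezvous = True  # pylint: disable=protected-access

@def_function.function
def recv():
  # Raw Recv: reads the tensor a sibling function sent under the key 'x'.
  return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)

recv._shared_rendezvous = True  # pylint: disable=protected-access

For the flag to matter, the two functions must actually run as separate (non-inlined) function calls under one parent, which is why the test below drives them through functional_ops.While in graph mode with grappler's meta-optimizer disabled.

The hunks below touch, in order (file names inferred from their content; the per-file headers did not survive extraction): tensorflow/core/framework/function.h; tensorflow/core/kernels/partitioned_function_ops.cc and partitioned_function_ops.h; tensorflow/python/BUILD; tensorflow/python/eager/BUILD; tensorflow/python/eager/def_function.py; tensorflow/python/eager/function.py; tensorflow/python/eager/function_test.py.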
@@ -344,6 +344,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
   static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
   static constexpr const char* const kIntsOnDeviceAttr =
       "experimental_ints_on_device";
+  static constexpr const char* const kSharedRendezvousAttr =
+      "shared_rendezvous";
 
   static constexpr const char* const kGradientOp = "SymbolicGradient";
   static constexpr const char* const kFuncAttr = "f";

@@ -43,7 +43,8 @@ namespace tensorflow {
 PartitionedCallOp::PartitionedCallOp(OpKernelConstruction* ctx)
     : AsyncOpKernel(ctx),
       func_(new NameAttrList),
-      config_proto_(new ConfigProto) {
+      config_proto_(new ConfigProto),
+      shared_rendezvous_(false) {
   OP_REQUIRES_OK(
       ctx, ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, func_.get()));
   string deprecated_config_serialized;

@@ -139,6 +140,11 @@ Status PartitionedCallOp::FillOutputDevices(
     return errors::NotFound("Failed to find definition for function \"",
                             func_->name(), "\"");
   }
+  auto func_attrs = fdef->attr();
+  auto attr = func_attrs.find(FunctionLibraryDefinition::kSharedRendezvousAttr);
+  if (attr != func_attrs.end() && attr->second.b()) {
+    shared_rendezvous_ = true;
+  }
 
   bool is_type_list;
   for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) {

@@ -245,6 +251,9 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
   run_opts.source_device =
       lib->device() == nullptr ? "" : lib->device()->name();
   run_opts.allow_dead_tensors = true;
+  if (shared_rendezvous_) {
+    run_opts.rendezvous = ctx->rendezvous();
+  }
 
   std::vector<Tensor>* rets = new std::vector<Tensor>;
   const string& func_name = func_->name();

@@ -58,6 +58,7 @@ class PartitionedCallOp : public AsyncOpKernel {
   std::unique_ptr<NameAttrList> func_;
   std::unique_ptr<ConfigProto> config_proto_;
   string executor_type_;
+  bool shared_rendezvous_;
   mutex mu_;
   // Cache the handle per FLR because this kernel may be instantiated for
   // a stateful op, different invocations of it may use different FLRs.

@@ -3062,6 +3062,7 @@ tf_gen_op_wrapper_private_py(name = "rnn_ops_gen")
 
 tf_gen_op_wrapper_private_py(
     name = "sendrecv_ops_gen",
+    visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/core:sendrecv_ops_op_lib",
     ],

@@ -462,6 +462,7 @@ cuda_py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_seed",
         "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:sendrecv_ops_gen",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",

@@ -521,6 +521,10 @@ class Function(object):
     self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
         python_function, input_signature)
     self._implements = experimental_implements
+    # If `True`, the function uses the rendezvous of the parent. This is only
+    # needed to support code where raw send/recv operations are inserted and
+    # when functions are run in graph mode where they may not be inlined.
+    self._shared_rendezvous = None
     self._autograph = autograph
     self._experimental_autograph_options = experimental_autograph_options
     self._experimental_relax_shapes = experimental_relax_shapes

@@ -629,6 +633,10 @@ class Function(object):
     if self._implements is not None:
       attributes = self._create_implements_attribute()
 
+    share = self._shared_rendezvous
+    if share is not None:
+      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share
+
     if self._experimental_compile is not None:
       attributes.update(_XlaMustCompile=bool(self._experimental_compile))
       if self._experimental_compile:
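The hunk above is where the Python-level flag becomes a FunctionDef attribute (SHARED_RENDEZVOUS_ATTRIBUTE_NAME, matching the C++ kSharedRendezvousAttr read in FillOutputDevices earlier). A hedged way to observe the stamped attribute, assuming the usual pattern that an inner tf.function's FunctionDef lands in the outer graph's function library:

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op

@def_function.function
def inner():
  return constant_op.constant(1)

inner._shared_rendezvous = True  # pylint: disable=protected-access

@def_function.function
def outer():
  return inner()

# Inspect the traced function library; the flagged function's def should
# carry the `shared_rendezvous` attr that PartitionedCallOp checks when it
# fills output devices.
graph_def = outer.get_concrete_function().graph.as_graph_def()
for fdef in graph_def.library.function:
  if 'shared_rendezvous' in fdef.attr:
    print(fdef.signature.name, fdef.attr['shared_rendezvous'].b)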
@@ -698,7 +706,8 @@ class Function(object):
       self._stateless_fn._name = self._name  # pylint: disable=protected-access
 
   def _clone(self, python_function):
-    return Function(
+    """Clone the function with different python function."""
+    f = Function(
         python_function=(self._python_function
                          if python_function is None else python_function),
         name=self._name,

@@ -709,6 +718,11 @@ class Function(object):
         experimental_relax_shapes=self._experimental_relax_shapes,
         experimental_compile=self._experimental_compile)
 
+    if self._shared_rendezvous:
+      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access
+
+    return f
+
   def _decorate(self, decorator):
     """Allows the captured Python function to be decorated in place.
 
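These two hunks also make `_clone` preserve the new flag: a Function gets re-wrapped through `_clone` in some code paths, and without the propagation above the copy would come back with `_shared_rendezvous = None`. A small hedged sanity check (private API, illustrative only):

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op

@def_function.function
def f():
  return constant_op.constant(1)

f._shared_rendezvous = True  # pylint: disable=protected-access

# `_clone(None)` re-wraps the same python_function; per the hunk above, the
# flag must survive the copy.
g = f._clone(python_function=None)  # pylint: disable=protected-access
assert g._shared_rendezvous  # pylint: disable=protected-access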
@@ -922,8 +936,8 @@ class Function(object):
     @function_lib.defun(autograph=False)
     def initialize_variables():
       op_map = object_identity.ObjectIdentityDictionary()
-      # Stack all the var_is_initialized values into one tensor and interpret the
-      # numpy value. This will reduce the number of RPCs between client and
+      # Stack all the var_is_initialized values into one tensor and interpret
+      # the numpy value. This will reduce the number of RPCs between client and
       # worker in the remote case.
       with ops.init_scope():
         var_is_initialized = []

@@ -86,6 +86,7 @@ ag_ctx = lazy_loader.LazyLoader(
 FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
 BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
 IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
+SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
 
 
 def _make_input_signature_hashable(elem, variable_map=None):

@@ -60,9 +60,11 @@ from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import gen_sendrecv_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import list_ops

@@ -858,6 +860,60 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     pool.map(stateful, [object() for _ in range(100)])
     self.assertEqual(float(v.read_value()), 0.0)
 
+  def testShareRendezvous(self):
+
+    # Disable grappler from inlining the functions. Note we run the send & recv
+    # in graph mode since with eager mode the function should automatically be
+    # inlined.
+    context.context().set_optimizer_experimental_options(
+        {'disable_meta_optimizer': True})
+
+    cpu = '/device:CPU:0'
+
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+
+    @def_function.function
+    def send():
+      x = constant_op.constant(1)
+      gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
+      return x
+
+    send._shared_rendezvous = True  # pylint: disable=protected-access
+
+    @def_function.function(input_signature=signature)
+    def send_body(n):
+      send()
+      return n - 1
+
+    @def_function.function
+    def recv():
+      return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)
+
+    recv._shared_rendezvous = True  # pylint: disable=protected-access
+
+    @def_function.function(input_signature=signature)
+    def recv_body(n):
+      recv()
+      return n - 1
+
+    @def_function.function(input_signature=signature)
+    def cond(n):
+      return n > 0
+
+    # Instead of calling the send & recv functions directly we want to call them
+    # through a functional while to ensure the rendezvous is shared across the
+    # while boundary.
+    @def_function.function
+    def fn(n):
+      functional_ops.While([n], cond.get_concrete_function(),
+                           send_body.get_concrete_function())
+      return functional_ops.While([n], cond.get_concrete_function(),
+                                  recv_body.get_concrete_function())
+
+    # Use a graph context since functions will not be automatically inlined
+    with context.graph_mode(), self.cached_session():
+      self.evaluate(fn(2))
+
   def disabled_testRandomSeed(self):
 
     @def_function.function