Parallel device: make tf.cond work executing eagerly

This takes the easy but not too satisfying "wrap it in a tf.function" approach, similar to what tf.vectorized_map does. This means we'll re-trace the cond's branches every time tf.cond runs.

If this ends up being a performance bottleneck there are a few things we can do. One is to check whether the condition parallel tensor will actually take different branches on different devices (and use the plain eager path if not). Another is to tweak the calling code (e.g. batch normalization) to wrap the cond itself in a tf.function, where the trace can be cached; a sketch of that option follows. We could also implement cond in the parallel device itself, with null optionals on devices that aren't taking a branch; that seems pretty complicated.
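As a rough illustration of the second option, here is a minimal sketch (not part of this change) of moving the cond inside an outer tf.function so its trace is cached across calls. It assumes two CPU components are available for the ParallelDevice; the tests below configure these as virtual devices.

import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device

# Assumption: two CPU components exist (the tests set these up virtually).
device = parallel_device.ParallelDevice(
    ["/job:localhost/device:CPU:0", "/job:localhost/device:CPU:1"])

@tf.function  # Traced once; subsequent calls reuse the cached cond trace.
def scaled(pred, x):
  return tf.cond(pred, lambda: x * 2., lambda: x * 4.)

with device:
  pred = device.pack([True, False])
  x = device.pack([[1.], [2.]])
  result = scaled(pred, x)
print(device.unpack(result))  # [[2.]] from CPU:0, [[8.]] from CPU:1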

PiperOrigin-RevId: 345741960
Change-Id: Iaa543e03a2dab96dc0fa0cd453f48718f42d31a8
Allen Lavoie 2020-12-04 13:12:29 -08:00 committed by TensorFlower Gardener
parent 023a9b14f8
commit 6b4ba7fb16
2 changed files with 101 additions and 15 deletions

tensorflow/python/distribute/parallel_device/parallel_device_test.py

@@ -20,6 +20,8 @@ from __future__ import print_function
 import os
 import threading
 
+from absl.testing import parameterized
+
 from tensorflow.python.distribute.parallel_device import parallel_device
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -31,6 +33,7 @@ from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -119,7 +122,7 @@ class _VirtualDeviceTestCase(test.TestCase):
     self.assertIn(self.device_type + ":1", self.device.components[1])
 
 
-class ParallelDeviceTests(_VirtualDeviceTestCase):
+class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
 
   def test_register_parallel_device(self):
     with self.device:
@@ -191,6 +194,47 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
       context._reset_context()
       config.set_synchronous_execution(previous)
 
+  @parameterized.named_parameters(
+      [("RunFunctionsEagerly", True),
+       ("", False)])
+  def test_cond(self, run_functions_eagerly):
+    try:
+      def_function.run_functions_eagerly(run_functions_eagerly)
+      with self.device:
+        pred = self.device.pack([True, False])
+        capture = self.device.pack([[1.], [2.]])
+        result = control_flow_ops.cond(
+            pred,
+            def_function.function(lambda: capture * 2.),
+            def_function.function(lambda: capture * 4.))
+      self.assertAllClose(
+          [[2.], [8.]], self.device.unpack(result))
+    finally:
+      def_function.run_functions_eagerly(False)
+
+  def test_cond_with_variable(self):
+    with self.device:
+      pred = self.device.pack([True, False])
+      capture = self.device.pack([[1.], [2.]])
+      v = None
+      @def_function.function
+      def true_branch():
+        nonlocal v
+        if v is None:
+          v = variables.Variable(constant_op.constant(2.))
+        return v * capture
+      result = control_flow_ops.cond(
+          pred, true_branch, def_function.function(lambda: capture * 4.))
+    self.assertAllClose(
+        [[2.], [8.]], self.device.unpack(result))
+    self.assertAllClose(
+        [2., 2.], self.device.unpack(v))
+    # There are two unique variable handles with separate storage.
+    h1, _ = self.device.unpack(v.handle)
+    gen_resource_variable_ops.assign_variable_op(h1, constant_op.constant(3.))
+    self.assertAllClose(
+        [3., 2.], self.device.unpack(v))
+
   def test_collective_in_function(self):
     if self.device_type == "TPU":
       self.skipTest("ParallelDevice collectives on TPUs need work")
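The control_flow_ops change below keys off tensor_util.constant_value: an ordinary eager predicate reports a concrete value, while a predicate packed across parallel-device components does not. A minimal sketch of that distinction, for illustration only:

from tensorflow.python.framework import constant_op, ops, tensor_util

# An ordinary eager predicate has a concrete constant value, so cond can
# simply call the chosen branch function directly.
pred = ops.convert_to_tensor(constant_op.constant(True))
print(tensor_util.constant_value(pred))  # True

# A predicate packed across parallel-device components, like
# device.pack([True, False]) in the tests above, has no single Python value:
# constant_value returns None, and cond must build an actual cond op instead.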

tensorflow/python/ops/control_flow_ops.py

@@ -72,6 +72,12 @@ cond_v2 = LazyLoader("cond_v2", globals(),
 while_v2 = LazyLoader("while_v2", globals(),
                       "tensorflow.python.ops.while_v2")
 
+# def_function also uses cond
+def_function = LazyLoader(
+    "def_function", globals(),
+    "tensorflow.python.eager.def_function")
+
 # We override the 'tuple' for a control flow op, so we keep python's
 # existing 'tuple' for later use in this module.
 _basetuple = tuple
@@ -1095,6 +1101,49 @@ def _UnpackIfSingleton(res):
   return res
 
 
+def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
+  """Special cases for `cond` when executing eagerly."""
+  pred = ops.convert_to_tensor(pred)
+  pred_constant_value = tensor_util.constant_value(pred)
+  if pred_constant_value is None:
+    # Eager tensors from a parallel device may not have a constant
+    # value. Running the cond op itself would work, but we don't have logic to
+    # build cond ops without wrapping in a function first.
+    if (not isinstance(true_fn, def_function.Function)
+        or not isinstance(false_fn, def_function.Function)):
+      raise TypeError("When running tf.cond on a parallel device, `true_fn` "
+                      "and `false_fn` must be decorated with `tf.function`.")
+    @def_function.function
+    def _parallel_device_cond_wrapper():
+      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
+    functions_run_eagerly = def_function.functions_run_eagerly()
+    if functions_run_eagerly:
+      # We need to use tf.function to deal with variable creation inside the
+      # cond, and skipping it because of run_functions_eagerly would just
+      # crash immediately.
+      logging.warning(
+          "It looks like tf.function behavior was disabled, perhaps using "
+          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
+          "tf.function to work. This primitive will override the disable.")
+      def_function.run_functions_eagerly(False)
+    try:
+      return _parallel_device_cond_wrapper()
+    finally:
+      if functions_run_eagerly is not None:
+        def_function.run_functions_eagerly(functions_run_eagerly)
+  else:
+    # For conditions which are eager tensors with a constant value (most of
+    # them), we only call the relevant branch function and execute it eagerly.
+    with ops.name_scope(name, "cond", [pred]):
+      if pred_constant_value:
+        result = true_fn()
+      else:
+        result = false_fn()
+      if not strict:
+        result = _UnpackIfSingleton(result)
+      return result
+
+
 # pylint: disable=redefined-outer-name
 # pylint: disable=g-doc-args
 @tf_export(v1=["cond"])
@@ -1174,11 +1223,6 @@ def cond(pred,
   ```
   """
-  # Always enable control flow v2 if building a function, regardless of toggle.
-  if (util.EnableControlFlowV2(ops.get_default_graph()) and
-      not context.executing_eagerly()):
-    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
-
   # We needed to make true_fn/false_fn keyword arguments for
   # backwards-compatibility. This check exists so that we can convert back to
   # having them be positional arguments.
@@ -1202,16 +1246,14 @@ def cond(pred,
   if not callable(false_fn):
     raise TypeError("false_fn must be callable.")
 
-  with ops.name_scope(name, "cond", [pred]):
-    if context.executing_eagerly():
-      if pred:
-        result = true_fn()
-      else:
-        result = false_fn()
-      if not strict:
-        result = _UnpackIfSingleton(result)
-      return result
+  if context.executing_eagerly():
+    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)
 
+  # Always enable control flow v2 if building a function, regardless of toggle.
+  if util.EnableControlFlowV2(ops.get_default_graph()):
+    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
+
+  with ops.name_scope(name, "cond", [pred]):
     # Add the Switch to the graph.
     if isinstance(pred, bool):
       raise TypeError("pred must not be a Python bool")
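
Taken together, this is the behavior the change enables end to end; a hedged sketch, again assuming two CPU components configured as in the tests:

import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device

device = parallel_device.ParallelDevice(
    ["/job:localhost/device:CPU:0", "/job:localhost/device:CPU:1"])

with device:
  pred = device.pack([True, False])
  x = device.pack([[1.], [2.]])
  # The branches must be tf.functions when cond runs eagerly on a packed
  # predicate; plain callables hit the TypeError added above, since the
  # predicate has no constant value.
  result = tf.cond(pred,
                   tf.function(lambda: x * 2.),
                   tf.function(lambda: x * 4.))
print(device.unpack(result))  # [[2.]] from CPU:0, [[8.]] from CPU:1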