Parallel device: make tf.cond work executing eagerly
This takes the easy but not too satisfying "wrap it in a tf.function" approach, similar to what tf.vectorized_map does. This means we'll re-trace the cond's branches every time tf.cond runs. If this ends up being a performance bottleneck there are a few things we can do. One is to check if the condition parallel tensor is actually going to take different branches on different devices (and do the eager thing if not). Another is to tweak the calling code (e.g. BN) to wrap the cond itself in a tf.function; there we'll be able to cache the trace. We could also implement cond in the parallel device, with null optionals if a device isn't taking a branch; that seems pretty complicated. PiperOrigin-RevId: 345741960 Change-Id: Iaa543e03a2dab96dc0fa0cd453f48718f42d31a8
This commit is contained in:
parent
023a9b14f8
commit
6b4ba7fb16
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
import os
|
||||
import threading
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute.parallel_device import parallel_device
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
@ -31,6 +33,7 @@ from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -119,7 +122,7 @@ class _VirtualDeviceTestCase(test.TestCase):
|
||||
self.assertIn(self.device_type + ":1", self.device.components[1])
|
||||
|
||||
|
||||
class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
|
||||
|
||||
def test_register_parallel_device(self):
|
||||
with self.device:
|
||||
@ -191,6 +194,47 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
context._reset_context()
|
||||
config.set_synchronous_execution(previous)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
[("RunFunctionsEagerly", True),
|
||||
("", False)])
|
||||
def test_cond(self, run_functions_eagerly):
|
||||
try:
|
||||
def_function.run_functions_eagerly(run_functions_eagerly)
|
||||
with self.device:
|
||||
pred = self.device.pack([True, False])
|
||||
capture = self.device.pack([[1.], [2.]])
|
||||
result = control_flow_ops.cond(
|
||||
pred,
|
||||
def_function.function(lambda: capture * 2.),
|
||||
def_function.function(lambda: capture * 4.))
|
||||
self.assertAllClose(
|
||||
[[2.], [8.]], self.device.unpack(result))
|
||||
finally:
|
||||
def_function.run_functions_eagerly(False)
|
||||
|
||||
def test_cond_with_variable(self):
|
||||
with self.device:
|
||||
pred = self.device.pack([True, False])
|
||||
capture = self.device.pack([[1.], [2.]])
|
||||
v = None
|
||||
@def_function.function
|
||||
def true_branch():
|
||||
nonlocal v
|
||||
if v is None:
|
||||
v = variables.Variable(constant_op.constant(2.))
|
||||
return v * capture
|
||||
result = control_flow_ops.cond(
|
||||
pred, true_branch, def_function.function(lambda: capture * 4.))
|
||||
self.assertAllClose(
|
||||
[[2.], [8.]], self.device.unpack(result))
|
||||
self.assertAllClose(
|
||||
[2., 2.], self.device.unpack(v))
|
||||
# There are two unique variable handles with separate storage.
|
||||
h1, _ = self.device.unpack(v.handle)
|
||||
gen_resource_variable_ops.assign_variable_op(h1, constant_op.constant(3.))
|
||||
self.assertAllClose(
|
||||
[3., 2.], self.device.unpack(v))
|
||||
|
||||
def test_collective_in_function(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
|
@ -72,6 +72,12 @@ cond_v2 = LazyLoader("cond_v2", globals(),
|
||||
while_v2 = LazyLoader("while_v2", globals(),
|
||||
"tensorflow.python.ops.while_v2")
|
||||
|
||||
# def_function also uses cond
|
||||
def_function = LazyLoader(
|
||||
"def_function", globals(),
|
||||
"tensorflow.python.eager.def_function")
|
||||
|
||||
|
||||
# We override the 'tuple' for a control flow op, so we keep python's
|
||||
# existing 'tuple' for later use in this module.
|
||||
_basetuple = tuple
|
||||
@ -1095,6 +1101,49 @@ def _UnpackIfSingleton(res):
|
||||
return res
|
||||
|
||||
|
||||
def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
|
||||
"""Special cases for `cond` when executing eagerly."""
|
||||
pred = ops.convert_to_tensor(pred)
|
||||
pred_constant_value = tensor_util.constant_value(pred)
|
||||
if pred_constant_value is None:
|
||||
# Eager tensors from a parallel device may not have a constant
|
||||
# value. Running the cond op itself would work, but we don't have logic to
|
||||
# build cond ops without wrapping in a function first.
|
||||
if (not isinstance(true_fn, def_function.Function)
|
||||
or not isinstance(false_fn, def_function.Function)):
|
||||
raise TypeError("When running tf.cond on a parallel device, `true_fn` "
|
||||
"and `false_fn` must be decorated with `tf.function`.")
|
||||
@def_function.function
|
||||
def _parallel_device_cond_wrapper():
|
||||
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
||||
functions_run_eagerly = def_function.functions_run_eagerly()
|
||||
if functions_run_eagerly:
|
||||
# We need to use tf.function to deal with variable creation inside the
|
||||
# cond, and skipping it because of run_functions_eagerly would just
|
||||
# crash immediately.
|
||||
logging.warning(
|
||||
"It looks like tf.function behavior was disabled, perhaps using "
|
||||
"tf.config.run_functions_eagerly. Parallelized tf.cond requires "
|
||||
"tf.function to work. This primitive will override the disable.")
|
||||
def_function.run_functions_eagerly(False)
|
||||
try:
|
||||
return _parallel_device_cond_wrapper()
|
||||
finally:
|
||||
if functions_run_eagerly is not None:
|
||||
def_function.run_functions_eagerly(functions_run_eagerly)
|
||||
else:
|
||||
# For conditions which are eager tensors with a constant value (most of
|
||||
# them), we only call the relevant branch function and execute it eagerly.
|
||||
with ops.name_scope(name, "cond", [pred]):
|
||||
if pred_constant_value:
|
||||
result = true_fn()
|
||||
else:
|
||||
result = false_fn()
|
||||
if not strict:
|
||||
result = _UnpackIfSingleton(result)
|
||||
return result
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
# pylint: disable=g-doc-args
|
||||
@tf_export(v1=["cond"])
|
||||
@ -1174,11 +1223,6 @@ def cond(pred,
|
||||
```
|
||||
|
||||
"""
|
||||
# Always enable control flow v2 if building a function, regardless of toggle.
|
||||
if (util.EnableControlFlowV2(ops.get_default_graph()) and
|
||||
not context.executing_eagerly()):
|
||||
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
||||
|
||||
# We needed to make true_fn/false_fn keyword arguments for
|
||||
# backwards-compatibility. This check exists so that we can convert back to
|
||||
# having them be positional arguments.
|
||||
@ -1202,16 +1246,14 @@ def cond(pred,
|
||||
if not callable(false_fn):
|
||||
raise TypeError("false_fn must be callable.")
|
||||
|
||||
with ops.name_scope(name, "cond", [pred]):
|
||||
if context.executing_eagerly():
|
||||
if pred:
|
||||
result = true_fn()
|
||||
else:
|
||||
result = false_fn()
|
||||
if not strict:
|
||||
result = _UnpackIfSingleton(result)
|
||||
return result
|
||||
if context.executing_eagerly():
|
||||
return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)
|
||||
|
||||
# Always enable control flow v2 if building a function, regardless of toggle.
|
||||
if util.EnableControlFlowV2(ops.get_default_graph()):
|
||||
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
||||
|
||||
with ops.name_scope(name, "cond", [pred]):
|
||||
# Add the Switch to the graph.
|
||||
if isinstance(pred, bool):
|
||||
raise TypeError("pred must not be a Python bool")
|
||||
|
Loading…
x
Reference in New Issue
Block a user