Support passing and returning Nones in TPUStrategy.

This is already supported in non-TPU strategies, and its absence here surprises users trying to migrate.
The lower-level input and output replication ops can't handle values that aren't convertible to Tensor, so we need to do some massaging around this. Nones in inputs are temporarily replaced with a placeholder constant, then swapped back to None after replication. Nones in outputs are simply not replicated, as output replication happens at a per-value granularity.
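Conceptually, the input-side handling is a mask-and-restore pattern over the flattened inputs. The sketch below is illustrative only; `_mask_nones` and `_restore_nones` are hypothetical helper names, not the code in this change (which does the equivalent inside tpu.split_compile_and_replicate).

import tensorflow as tf

def _mask_nones(flat_inputs):
  """Replace None entries with a placeholder constant so every element can
  be fed through the tensor-only replication op (hypothetical helper)."""
  return [
      tf.constant(0) if x is None else tf.convert_to_tensor(x)
      for x in flat_inputs
  ]

def _restore_nones(replicated, original_flat_inputs):
  """After replication, swap the placeholders back to None by consulting
  the original, pre-masking values (hypothetical helper)."""
  return [
      None if original is None else rep
      for rep, original in zip(replicated, original_flat_inputs)
  ]

# Example: a None survives the round trip.
original = [1, None, 2]
masked = _mask_nones(original)               # Tensors 1, 0, 2
restored = _restore_nones(masked, original)  # Tensor 1, None, Tensor 2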

PiperOrigin-RevId: 355690080
Change-Id: I9d2435e953c8feb7818a882cb5280327f310c919
Revan Sopher, 2021-02-04 12:47:10 -08:00; committed by TensorFlower Gardener
parent 581a3825f8
commit a0bd36e7f4
3 changed files with 64 additions and 22 deletions
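For context, the user-visible behavior this enables looks roughly like the following, mirroring the new test added below. This is a sketch against the public API and assumes `strategy` is an already-constructed tf.distribute.TPUStrategy; it is not taken verbatim from the change.

import tensorflow as tf

# Assumes `strategy` is an already-constructed tf.distribute.TPUStrategy.

@tf.function
def step():

  def computation(x):
    return x  # Echo the nested input, including the None, back out.

  # None may now appear inside nested args and inside returned structures.
  return strategy.experimental_local_results(
      strategy.run(computation, args=([1, [2, None]],)))

outputs = step()
# outputs[0][1][1] is None, matching the None that was passed in.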


@ -1439,15 +1439,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
if kwargs is None:
kwargs = {}
# Remove None at the end of args as they are not replicatable
# If there are None in the middle we can't do anything about it
# so let those cases fail.
# For example when Keras model predict is used they pass the targets as
# None. We want to handle it here so all client libraries don't have to
# do this as other strategies can handle None values better.
while args and args[-1] is None:
args = args[:-1]
# Used to re-structure flattened output tensors from `tpu.replicate()`
# into a structured format.
result = [[]]


@ -635,6 +635,26 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
results[1].backing_device)
def test_run_passing_and_returning_nones(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
@def_function.function
def train_step():
def computation(x):
return x
# Note that this input None is nested.
outputs = strategy.experimental_local_results(
strategy.run(computation, args=([1, [2, None]],)))
return outputs
results = train_step()
self.assertAllEqual(1, results[0][0].values[0])
self.assertAllEqual(2, results[0][1][0].values[0])
self.assertIsNone(results[0][1][1])
def test_composite_input_output(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
if strategy.num_replicas_in_sync != 2:


@ -39,6 +39,7 @@ from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -1172,7 +1173,7 @@ def _flatten_and_filter_composite(maybe_composite, non_composite_output,
def split_compile_and_replicate(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
inputs: Optional[List[List[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
@ -1275,8 +1276,9 @@ def split_compile_and_replicate(
for i in xrange(1, num_replicas):
nest.assert_same_structure(inputs[0], inputs[i])
# Flatten inputs.
flat_inputs = [
# Flatten inputs. This structure may contain None values, which will be
# handled later.
flat_inputs_with_nones = [
nest.flatten(per_replica_input, expand_composites=True)
for per_replica_input in inputs
]
@ -1285,9 +1287,14 @@ def split_compile_and_replicate(
is_composite = nest.flatten(nest.map_structure(
lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
# Converts inputs to Tensors.
flat_inputs = [[ops.convert_to_tensor(x) for x in inp]
for inp in flat_inputs]
# Converts inputs to Tensors, replacing Nones with a placeholder 0 since
# tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
flat_inputs = []
for inp in flat_inputs_with_nones:
flat_inputs.append([
constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
for x in inp
])
# Verifies that all replicas have matching numbers and types of inputs
flat_input_types = [x.dtype for x in flat_inputs[0]]
@ -1426,10 +1433,16 @@ def split_compile_and_replicate(
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
# Clobber replicated placeholders with Nones.
computation_inputs = [
None if inp is None else replicated for replicated, inp in zip(
flat_replicated_inputs, flat_inputs_with_nones[0])
]
# Unflatten the computation inputs to match original input structure.
computation_inputs = nest.pack_sequence_as(
structure=inputs[0],
flat_sequence=flat_replicated_inputs[:flat_input_arity],
flat_sequence=computation_inputs[:flat_input_arity],
expand_composites=True)
# If there is an infeed queue, adds the dequeued values to the
@ -1525,8 +1538,18 @@ def split_compile_and_replicate(
]
# Fan-out: Builds a TPUReplicatedOutput node for each output.
replicated_outputs = [[] for i in xrange(num_replicas)]
replicated_outputs = [[] for i in range(num_replicas)]
for i, t in enumerate(output_tensors):
# None values returned by the computation can't be sent to
# tpu_ops.tpu_replicated_output(), so we handle them specially here. We can
# avoid the placeholder 0 routine required on the inputs since outputs are
# replicated per-tensor, not per-replica, so we can skip replication.
if t is None:
for replica in range(num_replicas):
replicated_outputs[replica].append(None)
continue
# Fan-out: Builds a TPUReplicatedOutput node for each output.
ys = tpu_ops.tpu_replicated_output(
t, num_replicas, name="output{}".format(i))
@ -1534,7 +1557,7 @@ def split_compile_and_replicate(
# Wraps the outputs in identity operators so the names of any possible
# `fetch` nodes are preserved by the replication rewrite.
with ops.control_dependencies(control_deps):
for replica in xrange(num_replicas):
for replica in range(num_replicas):
replicated_outputs[replica].append(
array_ops.identity(
ys[replica], name="output_%d_shard_%d" % (i, replica)))
@ -1549,7 +1572,7 @@ def split_compile_and_replicate(
def _postprocess_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.
Args:
@ -1584,10 +1607,12 @@ def _postprocess_flat_outputs(
# Append `no_op` here so that fetching any return value of this function
# will trigger TPUExecute node.
outputs += (control_flow_ops.no_op(),)
maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
try:
with ops.device(core(0)):
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
o if isinstance(o, ops.Operation) else maybe_convert(o)
for o in outputs
]
except Exception as e:
@ -1616,6 +1641,8 @@ def _postprocess_flat_outputs(
# TODO(phawkins): extend the rewrite to elide these nodes instead.
new_output_tensors = []
for t in output_tensors:
if t is None:
new_output_tensors.append(None)
continue
with ops.device(t.device if t.device else core(0)):
o = array_ops.identity(t)
# pylint: disable=protected-access
@ -1627,7 +1654,7 @@ def _postprocess_flat_outputs(
def _postprocess_non_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.
Args:
@ -1643,8 +1670,12 @@ def _postprocess_non_flat_outputs(
# Flatten output items.
flat_outputs = nest.flatten(outputs, expand_composites=True)
# Convert all non-Operation outputs to Tensors.
# Convert all non-None non-Operation outputs to Tensors.
for i, o in enumerate(flat_outputs):
if o is None:
flat_outputs[i] = None
continue
if isinstance(o, ops.Operation):
raise ValueError(
"tpu.rewrite does not support Operation as return value in non-flat "