Support passing and returning Nones in TPUStrategy.
This is supported in non-TPU strategies, and its absence surprises users trying to migrate. The lower-level input and output replication ops can't handle values that aren't convertible to Tensor, so we need to do some massaging around this. Nones in inputs are temporarily replaced with a constant and then swapped back to None after replication. Nones in outputs are simply not replicated, since output replication happens at per-value granularity.

PiperOrigin-RevId: 355690080
Change-Id: I9d2435e953c8feb7818a882cb5280327f310c919
parent 581a3825f8
commit a0bd36e7f4
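To make the mechanism in the commit message concrete, here is a small standalone sketch. It is not TensorFlow code and nothing in it appears in the diff below; `replicate_value`, `replicate_inputs`, and `replicate_outputs` are hypothetical stand-ins for `tpu_ops.tpu_replicated_input()` / `tpu_ops.tpu_replicated_output()` and the surrounding plumbing, used only to show the placeholder-then-clobber handling of None inputs and the skip-replication handling of None outputs.

```python
# Standalone sketch, not TensorFlow code: hypothetical helpers mimicking the
# None handling described in the commit message.


def replicate_value(value, num_replicas):
  # Stand-in for the per-value replication op, which only accepts values that
  # are convertible to Tensor (so never None).
  assert value is not None
  return [value] * num_replicas


def replicate_inputs(flat_inputs_with_nones, num_replicas):
  # Nones can't be fed to the replication op, so substitute a placeholder 0,
  # replicate everything, then clobber the replicated placeholders back to None.
  placeholders = [0 if x is None else x for x in flat_inputs_with_nones]
  replicated = [replicate_value(x, num_replicas) for x in placeholders]
  return [None if original is None else rep
          for original, rep in zip(flat_inputs_with_nones, replicated)]


def replicate_outputs(flat_outputs, num_replicas):
  # Output replication happens per value, so a None output is never sent to
  # the replication op at all; it is simply emitted as None on every replica.
  per_replica = [[] for _ in range(num_replicas)]
  for value in flat_outputs:
    if value is None:
      for replica in range(num_replicas):
        per_replica[replica].append(None)
      continue
    for replica, rep in enumerate(replicate_value(value, num_replicas)):
      per_replica[replica].append(rep)
  return per_replica


print(replicate_inputs([1, None, 3], num_replicas=2))  # [[1, 1], None, [3, 3]]
print(replicate_outputs([1, None], num_replicas=2))    # [[1, None], [1, None]]
```

The real change applies the same two ideas inside `split_compile_and_replicate()` in the diff below.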
@@ -1439,15 +1439,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       if kwargs is None:
         kwargs = {}

-      # Remove None at the end of args as they are not replicatable
-      # If there are None in the middle we can't do anything about it
-      # so let those cases fail.
-      # For example when Keras model predict is used they pass the targets as
-      # None. We want to handle it here so all client libraries don't have to
-      # do this as other strategies can handle None values better.
-      while args and args[-1] is None:
-        args = args[:-1]
-
       # Used to re-structure flattened output tensors from `tpu.replicate()`
       # into a structured format.
       result = [[]]
@@ -635,6 +635,26 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                         results[1].backing_device)

+  def test_run_passing_and_returning_nones(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+
+    @def_function.function
+    def train_step():
+
+      def computation(x):
+        return x
+
+      # Note that this input None is nested.
+      outputs = strategy.experimental_local_results(
+          strategy.run(computation, args=([1, [2, None]],)))
+      return outputs
+
+    results = train_step()
+
+    self.assertAllEqual(1, results[0][0].values[0])
+    self.assertAllEqual(2, results[0][1][0].values[0])
+    self.assertIsNone(results[0][1][1])
+
   def test_composite_input_output(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
     if strategy.num_replicas_in_sync != 2:
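For contrast with the TPU-specific test above, a hypothetical user-level sketch using only the public `tf.distribute` API (none of this is part of the change): this is the calling pattern that already worked under non-TPU strategies, where e.g. Keras `predict` passes the targets as None, and which `TPUStrategy` now accepts as well.

```python
import tensorflow as tf

# Placeholder strategy for illustration; the test above uses a TPUStrategy.
strategy = tf.distribute.get_strategy()


def predict_step(features, targets):
  del targets  # Keras-style predict passes the targets as None.
  return features, None  # Nones may also appear in the outputs.


features = tf.constant([1.0, 2.0])
outputs = strategy.run(predict_step, args=(features, None))
```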
@@ -39,6 +39,7 @@ from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import config
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -1172,7 +1173,7 @@ def _flatten_and_filter_composite(maybe_composite, non_composite_output,

 def split_compile_and_replicate(
     computation: Callable[..., Any],
-    inputs: List[List[Optional[core_types.Tensor]]] = None,
+    inputs: Optional[List[List[core_types.Tensor]]] = None,
     infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
     device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
     name: Optional[Text] = None,
@@ -1275,8 +1276,9 @@ def split_compile_and_replicate(
   for i in xrange(1, num_replicas):
     nest.assert_same_structure(inputs[0], inputs[i])

-  # Flatten inputs.
-  flat_inputs = [
+  # Flatten inputs. This structure may contain None values, which will be
+  # handled later.
+  flat_inputs_with_nones = [
       nest.flatten(per_replica_input, expand_composites=True)
       for per_replica_input in inputs
   ]
@@ -1285,9 +1287,14 @@ def split_compile_and_replicate(
   is_composite = nest.flatten(nest.map_structure(
       lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))

-  # Converts inputs to Tensors.
-  flat_inputs = [[ops.convert_to_tensor(x) for x in inp]
-                 for inp in flat_inputs]
+  # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
+  # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
+  flat_inputs = []
+  for inp in flat_inputs_with_nones:
+    flat_inputs.append([
+        constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
+        for x in inp
+    ])

   # Verifies that all replicas have matching numbers and types of inputs
   flat_input_types = [x.dtype for x in flat_inputs[0]]
@@ -1426,10 +1433,16 @@ def split_compile_and_replicate(
                      attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access

+    # Clobber replicated placeholders with Nones.
+    computation_inputs = [
+        None if inp is None else replicated for replicated, inp in zip(
+            flat_replicated_inputs, flat_inputs_with_nones[0])
+    ]
+
     # Unflatten the computation inputs to match original input structure.
     computation_inputs = nest.pack_sequence_as(
         structure=inputs[0],
-        flat_sequence=flat_replicated_inputs[:flat_input_arity],
+        flat_sequence=computation_inputs[:flat_input_arity],
         expand_composites=True)

     # If there is an infeed queue, adds the dequeued values to the
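As a side note on why the clobber-then-pack sequence above round-trips cleanly: `nest` (whose public counterpart is `tf.nest`) treats None as a leaf, so it survives flattening and is put back in place by `pack_sequence_as`. An illustrative check, assuming only the public API:

```python
import tensorflow as tf

structure = [1, [2, None]]

# None is a leaf: it is preserved through flattening...
flat = tf.nest.flatten(structure, expand_composites=True)
print(flat)  # [1, 2, None]

# ...and restored in its original position when packing back.
restored = tf.nest.pack_sequence_as(structure, flat, expand_composites=True)
print(restored)  # [1, [2, None]]
```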
@@ -1525,8 +1538,18 @@ def split_compile_and_replicate(
   ]

   # Fan-out: Builds a TPUReplicatedOutput node for each output.
-  replicated_outputs = [[] for i in xrange(num_replicas)]
+  replicated_outputs = [[] for i in range(num_replicas)]
   for i, t in enumerate(output_tensors):

+    # None values returned by the computation can't be sent to
+    # tpu_ops.tpu_replicated_output(), we handle them specially here. We can
+    # avoid the placeholder 0 routine required on the inputs since outputs are
+    # replicated per-tensor, not per-replica, so we can skip replication.
+    if t is None:
+      for replica in range(num_replicas):
+        replicated_outputs[replica].append(None)
+      continue
+
+    # Fan-out: Builds a TPUReplicatedOutput node for each output.
     ys = tpu_ops.tpu_replicated_output(
         t, num_replicas, name="output{}".format(i))
@@ -1534,7 +1557,7 @@ def split_compile_and_replicate(
     # Wraps the outputs in identity operators so the names of any possible
     # `fetch` nodes are preserved by the replication rewrite.
     with ops.control_dependencies(control_deps):
-      for replica in xrange(num_replicas):
+      for replica in range(num_replicas):
         replicated_outputs[replica].append(
             array_ops.identity(
                 ys[replica], name="output_%d_shard_%d" % (i, replica)))
@@ -1549,7 +1572,7 @@ def split_compile_and_replicate(

 def _postprocess_flat_outputs(
     outputs: Any
-) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
+) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
   """Validates non-flat outputs, add backs device assignments and other attrs.

   Args:
@@ -1584,10 +1607,12 @@ def _postprocess_flat_outputs(
   # Append `no_op` here so that fetching any return value of this function
   # will trigger TPUExecute node.
   outputs += (control_flow_ops.no_op(),)
+
+  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
   try:
     with ops.device(core(0)):
       outputs = [
-          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+          o if isinstance(o, ops.Operation) else maybe_convert(o)
           for o in outputs
       ]
   except Exception as e:
@@ -1616,6 +1641,9 @@ def _postprocess_flat_outputs(
   # TODO(phawkins): extend the rewrite to elide these nodes instead.
   new_output_tensors = []
   for t in output_tensors:
+    if t is None:
+      new_output_tensors.append(None)
+      continue
     with ops.device(t.device if t.device else core(0)):
       o = array_ops.identity(t)
       # pylint: disable=protected-access
@@ -1627,7 +1654,7 @@ def _postprocess_flat_outputs(

 def _postprocess_non_flat_outputs(
     outputs: Any
-) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
+) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
   """Validates non-flat outputs, add backs device assignments and other attrs.

   Args:
@@ -1643,8 +1670,12 @@ def _postprocess_non_flat_outputs(
   # Flatten output items.
   flat_outputs = nest.flatten(outputs, expand_composites=True)

-  # Convert all non-Operation outputs to Tensors.
+  # Convert all non-None non-Operation outputs to Tensors.
   for i, o in enumerate(flat_outputs):
+    if o is None:
+      flat_outputs[i] = None
+      continue
+
     if isinstance(o, ops.Operation):
       raise ValueError(
           "tpu.rewrite does not support Operation as return value in non-flat "