diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index b611b2a8cd8..8c2d15dadf6 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -1439,15 +1439,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if kwargs is None:
       kwargs = {}
 
-    # Remove None at the end of args as they are not replicatable
-    # If there are None in the middle we can't do anything about it
-    # so let those cases fail.
-    # For example when Keras model predict is used they pass the targets as
-    # None. We want to handle it here so all client libraries don't have to
-    # do this as other strategies can handle None values better.
-    while args and args[-1] is None:
-      args = args[:-1]
-
     # Used to re-structure flattened output tensors from `tpu.replicate()`
     # into a structured format.
     result = [[]]
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index a4b767ac03b..0601615b3a6 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -635,6 +635,26 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                         results[1].backing_device)
 
+  def test_run_passing_and_returning_nones(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+
+    @def_function.function
+    def train_step():
+
+      def computation(x):
+        return x
+
+      # Note that this input None is nested.
+      outputs = strategy.experimental_local_results(
+          strategy.run(computation, args=([1, [2, None]],)))
+      return outputs
+
+    results = train_step()
+
+    self.assertAllEqual(1, results[0][0].values[0])
+    self.assertAllEqual(2, results[0][1][0].values[0])
+    self.assertIsNone(results[0][1][1])
+
   def test_composite_input_output(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
     if strategy.num_replicas_in_sync != 2:
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index a6bc1a98913..4e15b80a3a6 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -39,6 +39,7 @@ from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import config
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -1172,7 +1173,7 @@ def _flatten_and_filter_composite(maybe_composite, non_composite_output,
 
 def split_compile_and_replicate(
     computation: Callable[..., Any],
-    inputs: List[List[Optional[core_types.Tensor]]] = None,
+    inputs: Optional[List[List[core_types.Tensor]]] = None,
     infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
     device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
     name: Optional[Text] = None,
@@ -1275,8 +1276,9 @@ def split_compile_and_replicate(
   for i in xrange(1, num_replicas):
     nest.assert_same_structure(inputs[0], inputs[i])
 
-  # Flatten inputs.
-  flat_inputs = [
+  # Flatten inputs. This structure may contain None values, which will be
+  # handled later.
+  flat_inputs_with_nones = [
       nest.flatten(per_replica_input, expand_composites=True)
       for per_replica_input in inputs
   ]
@@ -1285,9 +1287,14 @@
   is_composite = nest.flatten(nest.map_structure(
       lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
 
-  # Converts inputs to Tensors.
-  flat_inputs = [[ops.convert_to_tensor(x) for x in inp]
-                 for inp in flat_inputs]
+  # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
+  # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
+  flat_inputs = []
+  for inp in flat_inputs_with_nones:
+    flat_inputs.append([
+        constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
+        for x in inp
+    ])
 
   # Verifies that all replicas have matching numbers and types of inputs
   flat_input_types = [x.dtype for x in flat_inputs[0]]
@@ -1426,10 +1433,16 @@
                          attr_value_pb2.AttrValue(b=True))
         # pylint: enable=protected-access
 
+      # Clobber replicated placeholders with Nones.
+      computation_inputs = [
+          None if inp is None else replicated for replicated, inp in zip(
+              flat_replicated_inputs, flat_inputs_with_nones[0])
+      ]
+
       # Unflatten the computation inputs to match original input structure.
       computation_inputs = nest.pack_sequence_as(
           structure=inputs[0],
-          flat_sequence=flat_replicated_inputs[:flat_input_arity],
+          flat_sequence=computation_inputs[:flat_input_arity],
           expand_composites=True)
 
       # If there is an infeed queue, adds the dequeued values to the
@@ -1525,8 +1538,18 @@
     ]
 
   # Fan-out: Builds a TPUReplicatedOutput node for each output.
-  replicated_outputs = [[] for i in xrange(num_replicas)]
+  replicated_outputs = [[] for i in range(num_replicas)]
   for i, t in enumerate(output_tensors):
+
+    # None values returned by the computation can't be sent to
+    # tpu_ops.tpu_replicated_output(), we handle them specially here. We can
+    # avoid the placeholder 0 routine required on the inputs since outputs are
+    # replicated per-tensor, not per-replica, so we can skip replication.
+    if t is None:
+      for replica in range(num_replicas):
+        replicated_outputs[replica].append(None)
+      continue
+
     # Fan-out: Builds a TPUReplicatedOutput node for each output.
     ys = tpu_ops.tpu_replicated_output(
         t, num_replicas, name="output{}".format(i))
@@ -1534,7 +1557,7 @@
     # Wraps the outputs in identity operators so the names of any possible
     # `fetch` nodes are preserved by the replication rewrite.
     with ops.control_dependencies(control_deps):
-      for replica in xrange(num_replicas):
+      for replica in range(num_replicas):
         replicated_outputs[replica].append(
             array_ops.identity(
                 ys[replica], name="output_%d_shard_%d" % (i, replica)))
@@ -1549,7 +1572,7 @@
 
 def _postprocess_flat_outputs(
     outputs: Any
-  ) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
+) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
   """Validates non-flat outputs, add backs device assignments and other attrs.
 
   Args:
@@ -1584,10 +1607,12 @@
   # Append `no_op` here so that fetching any return value of this function
   # will trigger TPUExecute node.
   outputs += (control_flow_ops.no_op(),)
+
+  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
   try:
     with ops.device(core(0)):
       outputs = [
-          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+          o if isinstance(o, ops.Operation) else maybe_convert(o)
           for o in outputs
       ]
   except Exception as e:
@@ -1616,6 +1641,9 @@
   # TODO(phawkins): extend the rewrite to elide these nodes instead.
   new_output_tensors = []
   for t in output_tensors:
+    if t is None:
+      new_output_tensors.append(None)
+      continue
     with ops.device(t.device if t.device else core(0)):
       o = array_ops.identity(t)
       # pylint: disable=protected-access
@@ -1627,7 +1654,7 @@
 
 def _postprocess_non_flat_outputs(
     outputs: Any
-  ) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
+) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
   """Validates non-flat outputs, add backs device assignments and other attrs.
 
   Args:
@@ -1643,8 +1670,12 @@
   # Flatten output items.
   flat_outputs = nest.flatten(outputs, expand_composites=True)
 
-  # Convert all non-Operation outputs to Tensors.
+  # Convert all non-None non-Operation outputs to Tensors.
   for i, o in enumerate(flat_outputs):
+    if o is None:
+      flat_outputs[i] = None
+      continue
+
     if isinstance(o, ops.Operation):
       raise ValueError(
           "tpu.rewrite does not support Operation as return value in non-flat "
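
For context, below is a minimal usage sketch of the behavior this change enables: nested structures passed to and returned from TPUStrategy.run may now contain None leaves, mirroring the new test_run_passing_and_returning_nones test above. The cluster resolver setup and the empty tpu="" address are assumptions for illustration, the function name step is a placeholder, and on older TF versions the strategy class may be tf.distribute.experimental.TPUStrategy rather than tf.distribute.TPUStrategy.

import tensorflow as tf

# Assumed TPU setup; the empty address picks up a locally configured TPU worker.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

@tf.function
def step():

  def computation(x):
    # Echoes the nested input, including the None leaf, back to the caller.
    return x

  return strategy.experimental_local_results(
      strategy.run(computation, args=([1, [2, None]],)))

results = step()
# Per the new test: the first replica's output preserves the structure
# [1, [2, None]], so the trailing None comes back as None rather than being
# dropped or rejected.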