diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index 53af0c73b0b..55cb4587a73 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.util import nest
 
 
 class InputIterationTest(test.TestCase, parameterized.TestCase):
@@ -97,6 +98,37 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
       results.append(output)
     self._validate_outputs(results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy
+          ],
+          mode=["eager"]))
+  def testNestedOutput(self, distribution):
+    dataset = self._get_dataset()
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def run(iterator):
+
+      def computation(x):
+        return [{
+            "a": x - 1,
+            "b": x + 1
+        }]
+
+      inputs = next(iterator)
+      outputs = distribution.experimental_run_v2(computation, args=(inputs,))
+      return nest.map_structure(distribution.experimental_local_results,
+                                outputs)
+
+    results = run(input_iterator)
+    for replica in range(distribution.num_replicas_in_sync):
+      # The input dataset is range(10), so the replica id is the same as the input.
+      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
+      self.assertAllEqual(results[0]["b"][replica], [replica + 1])
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.all_strategies,
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 85ff25439c3..f6967d858aa 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -817,7 +817,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # Remove all no ops that may have been added during 'tpu.replicate()'
     if isinstance(result[0], list):
       result[0] = [
-          output for output in result[0] if tensor_util.is_tensor(output)
+          output for output in result[0] if not isinstance(
+              output, ops.Operation)
       ]
 
     # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
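Usage note (not part of the patch): below is a minimal sketch of the pattern `testNestedOutput` exercises — a replica function that returns a nested structure (a list of dicts), with `nest.map_structure` applied so that `experimental_local_results` unwraps every leaf into per-replica values. It assumes the public TF API surface at the time of this change (`experimental_run_v2`, `experimental_distribute_dataset`); the particular strategy and dataset are illustrative only.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# Batch so the distributed dataset can split elements across replicas.
dataset = tf.data.Dataset.range(8).batch(4)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def step(it):

  def computation(x):
    # Nested output: a list containing a dict of tensors, as in the test.
    return [{"a": x - 1, "b": x + 1}]

  outputs = strategy.experimental_run_v2(computation, args=(next(it),))
  # Map over each leaf tensor so every leaf becomes a tuple of per-replica
  # values rather than a single PerReplica object.
  return tf.nest.map_structure(strategy.experimental_local_results, outputs)

per_replica = step(iterator)
# per_replica[0]["a"] is a tuple with one tensor per replica.
```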