Fix a bug for TPUStrategy with nested output.

PiperOrigin-RevId: 286473428
Change-Id: I02c44b36848a2edea3c94d02d672fcbe04100a03
This commit is contained in:
Ruoxin Sang 2019-12-19 15:18:15 -08:00 committed by TensorFlower Gardener
parent 45f2aab17f
commit 69bb090113
2 changed files with 34 additions and 1 deletions

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
class InputIterationTest(test.TestCase, parameterized.TestCase):
@ -97,6 +98,37 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
results.append(output)
self._validate_outputs(results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy
],
mode=["eager"]))
def testNestedOutput(self, distribution):
dataset = self._get_dataset()
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
@def_function.function
def run(iterator):
def computation(x):
return [{
"a": x - 1,
"b": x + 1
}]
inputs = next(iterator)
outputs = distribution.experimental_run_v2(computation, args=(inputs,))
return nest.map_structure(distribution.experimental_local_results,
outputs)
results = run(input_iterator)
for replica in range(distribution.num_replicas_in_sync):
# The input dataset is range(10), so the replica id is same as input.
self.assertAllEqual(results[0]["a"][replica], [replica - 1])
self.assertAllEqual(results[0]["b"][replica], [replica + 1])
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,

View File

@ -817,7 +817,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# Remove all no ops that may have been added during 'tpu.replicate()'
if isinstance(result[0], list):
result[0] = [
output for output in result[0] if tensor_util.is_tensor(output)
output for output in result[0] if not isinstance(
output, ops.Operation)
]
# Workaround for `tpu.replicate` behaviour when single `Tensor` returned.