Fix a bug for TPUStrategy with nested output.
PiperOrigin-RevId: 286473428
Change-Id: I02c44b36848a2edea3c94d02d672fcbe04100a03
commit 69bb090113
parent 45f2aab17f
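A minimal sketch of the nested-output pattern this change is meant to support, mirroring the new test below but written against the public TF 2.1-era API (the `step`/`computation` names, the `range(10)` dataset, and the use of `MirroredStrategy` as a stand-in for a configured TPUStrategy are illustrative assumptions):

import tensorflow as tf

# Stand-in strategy; the fix itself targets TPUStrategy, which needs a TPU system.
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(10).batch(2)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def step(iterator):

  def computation(x):
    # The replica function returns a nested structure (a list containing a dict),
    # which is the case this commit fixes for TPUStrategy.
    return [{"a": x - 1, "b": x + 1}]

  outputs = strategy.experimental_run_v2(computation, args=(next(iterator),))
  # Unpack the per-replica values for every leaf of the nested structure.
  return tf.nest.map_structure(strategy.experimental_local_results, outputs)

results = step(dist_iterator)  # e.g. results[0]["a"] holds one value per replica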
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.util import nest
 
 
 class InputIterationTest(test.TestCase, parameterized.TestCase):
@@ -97,6 +98,37 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
       results.append(output)
     self._validate_outputs(results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy
+          ],
+          mode=["eager"]))
+  def testNestedOutput(self, distribution):
+    dataset = self._get_dataset()
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def run(iterator):
+
+      def computation(x):
+        return [{
+            "a": x - 1,
+            "b": x + 1
+        }]
+
+      inputs = next(iterator)
+      outputs = distribution.experimental_run_v2(computation, args=(inputs,))
+      return nest.map_structure(distribution.experimental_local_results,
+                                outputs)
+
+    results = run(input_iterator)
+    for replica in range(distribution.num_replicas_in_sync):
+      # The input dataset is range(10), so the replica id is same as input.
+      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
+      self.assertAllEqual(results[0]["b"][replica], [replica + 1])
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.all_strategies,
@@ -817,7 +817,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # Remove all no ops that may have been added during 'tpu.replicate()'
     if isinstance(result[0], list):
       result[0] = [
-          output for output in result[0] if tensor_util.is_tensor(output)
+          output for output in result[0] if not isinstance(
+              output, ops.Operation)
       ]
 
     # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
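The tpu_strategy.py change above swaps the output filter: `tensor_util.is_tensor` keeps only plain tensors, so a nested Python structure returned by the replica function (such as the list-of-dict output in the new test) was discarded along with the no-op `Operation`s appended by `tpu.replicate()`, whereas filtering out only `ops.Operation` keeps nested outputs intact. A small sketch of the difference, using hypothetical placeholder values and the module paths from this commit's TF version, built in graph mode so `tf.no_op()` yields an `Operation`:

import tensorflow as tf
from tensorflow.python.framework import ops, tensor_util

with tf.Graph().as_default():
  nested = [{"a": tf.constant(1), "b": tf.constant(2)}]  # structured replica output
  noop = tf.no_op()                         # stands in for a tpu.replicate() no-op
  result0 = [nested, noop]

  # Old filter keeps only plain tensors, so the nested structure is dropped too.
  old = [o for o in result0 if tensor_util.is_tensor(o)]          # -> []
  # New filter removes only Operations and keeps the nested output.
  new = [o for o in result0 if not isinstance(o, ops.Operation)]  # -> [nested]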