Changed gradient of dynamic_partition to use parallel_dynamic_stitch.

PiperOrigin-RevId: 266189934
This commit is contained in:
A. Unique TensorFlower 2019-08-29 11:41:36 -07:00 committed by TensorFlower Gardener
parent 9b9bea6515
commit 062f659524

View File

@ -39,7 +39,8 @@ def _DynamicPartitionGrads(op, *grads):
math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
partitioned_indices = data_flow_ops.dynamic_partition(
original_indices, indices, num_partitions)
reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
reconstructed = data_flow_ops.parallel_dynamic_stitch(partitioned_indices,
grads)
reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
return [reconstructed, None]