Changed gradient of dynamic_partition to use parallel_dynamic_stitch.
PiperOrigin-RevId: 266189934
This commit is contained in:
parent
9b9bea6515
commit
062f659524
@ -39,7 +39,8 @@ def _DynamicPartitionGrads(op, *grads):
|
||||
math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
|
||||
partitioned_indices = data_flow_ops.dynamic_partition(
|
||||
original_indices, indices, num_partitions)
|
||||
reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
|
||||
reconstructed = data_flow_ops.parallel_dynamic_stitch(partitioned_indices,
|
||||
grads)
|
||||
reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
|
||||
return [reconstructed, None]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user