Update the gradient of tf.broadcast_to to use tf.reduce_sum instead of tf.unsorted_segment_sum. Add new high-rank tests for tf.broadcast_to and its gradient.

PiperOrigin-RevId: 233451123
Author: A. Unique TensorFlower (2019-02-11 11:59:29 -08:00), committed by TensorFlower Gardener
parent f35ec21e23
commit f5593ff762
2 changed files with 41 additions and 15 deletions
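
The idea behind the change: tf.broadcast_to copies each input element to every output position it maps to, so the gradient with respect to the input is the upstream gradient summed over the broadcast axes. A minimal NumPy sketch of that identity (illustrative only, not part of this commit; the shapes and axis choices are arbitrary examples):

import numpy as np

x = np.random.randn(2, 1, 3).astype(np.float32)    # input
y = np.broadcast_to(x, (4, 2, 5, 3))               # broadcast output
dy = np.random.randn(*y.shape).astype(np.float32)  # upstream gradient

# Axis 0 of y is newly added and axis 2 stretches a size-1 input axis,
# so those are the reduction axes; keepdims=True lets the result reshape
# cleanly back to the input shape.
dx = dy.sum(axis=(0, 2), keepdims=True).reshape(x.shape)
assert dx.shape == x.shape  # (2, 1, 3)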

tensorflow/python/kernel_tests/broadcast_to_ops_test.py

@@ -76,6 +76,26 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
       v_np = np.broadcast_to(x, output_shape)
       self.assertAllEqual(v_tf.eval(), v_np)
 
+  @test_util.run_deprecated_v1
+  def testBroadcastToShapeLargerDim(self):
+    input_shape = [2, 1, 3, 2, 2, 2]
+    output_shape = [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 15, 3, 2, 2, 2]
+    with self.cached_session(use_gpu=True):
+      x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+      v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+      v_np = np.broadcast_to(x, output_shape)
+      self.assertAllEqual(v_tf.eval(), v_np)
+
+  @test_util.run_deprecated_v1
+  def testBroadcastToShapeLargerDim2(self):
+    input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
+    output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
+    with self.cached_session(use_gpu=True):
+      x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+      v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+      v_np = np.broadcast_to(x, output_shape)
+      self.assertAllEqual(v_tf.eval(), v_np)
+
   @test_util.run_deprecated_v1
   def testBroadcastToScalar(self):
     with self.session(use_gpu=True):
@@ -88,8 +108,9 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
   def testBroadcastScalarToNonScalar(self):
     with self.session(use_gpu=True):
       x = np.array(1.0, dtype=np.float)
-      v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4])
-      v_np = np.broadcast_to(x, [2, 3, 4])
+      v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4,
+                                                                1, 1, 1])
+      v_np = np.broadcast_to(x, [2, 3, 4, 1, 1, 1])
       self.assertAllEqual(v_tf.eval(), v_np)
 
   @test_util.run_deprecated_v1
@@ -148,6 +169,18 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
                                                     out, out.get_shape())
       self.assertLess(err, 1e-4)
 
+  @test_util.run_deprecated_v1
+  def testGradientWithLargeDim(self):
+    input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
+    output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
+    x = constant_op.constant(np.array(np.random.randn(*input_shape),
+                                      dtype=np.float32))
+    v = array_ops.broadcast_to(x, output_shape)
+    out = 2 * v
+    with self.cached_session():
+      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                    out, out.get_shape())
+      self.assertLess(err, 1e-4)
 
 if __name__ == "__main__":
   test_lib.main()
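
For reference, the high-rank shape pairs in the new tests follow NumPy's broadcasting rule: shapes are right-aligned, and each input dimension must equal the corresponding output dimension or be 1. A standalone check with the shapes from testBroadcastToShapeLargerDim2 (an illustrative sketch, not part of the test file):

import numpy as np

x = np.zeros([2, 1, 3, 2, 2, 2, 1, 1, 1])
# Right-aligned against the last nine output dims, every input dim either
# matches or is 1; the three leading output dims are simply prepended.
y = np.broadcast_to(x, [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3])
print(y.shape)  # (1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3)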

tensorflow/python/ops/array_grad.py

@@ -839,18 +839,11 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
 def _BroadcastToGrad(op, grad):
   input_value = op.inputs[0]
   broadcast_shape = op.inputs[1]
-  # Assign ids for each position in input_value.
   input_value_shape = array_ops.shape(input_value)
-  input_value_size = array_ops.size(input_value)
-  ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
-  broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
-  # Group by ids and sum its gradients.
-  grad_flatten = array_ops.reshape(grad, [-1])
-  broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
-  # TODO(apassos): Use reduce_sum for gradient now that we only support
-  # the usual numpy broadcast semantics.
-  updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
-                                                       broadcast_ids_flatten,
-                                                       input_value_size)
-  updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+  _, reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape,
+                                                            input_value_shape)
+  updates_grad_reshaped = math_ops.reduce_sum(grad,
+                                              axis=reduction_axes,
+                                              keepdims=True)
+  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
   return [updates_grad, None]
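
One way to sanity-check the rewritten gradient end to end is with public TF 2.x APIs (a hedged sketch in eager style, unlike the graph-mode tests above; shapes are arbitrary): the gradient TensorFlow returns for tf.broadcast_to should match the upstream gradient reduced over the broadcast axes, which is what broadcast_gradient_args computes internally.

import tensorflow as tf

x = tf.random.normal([2, 1, 3])
dy = tf.ones([4, 2, 5, 3])  # upstream gradient of all ones

with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.broadcast_to(x, [4, 2, 5, 3])
dx = tape.gradient(y, x, output_gradients=dy)

# Manual reference: sum dy over the added axis 0 and the stretched axis 2
# (the reduction axes the new code obtains from broadcast_gradient_args),
# then reshape back to the input shape.
ref = tf.reshape(tf.reduce_sum(dy, axis=[0, 2], keepdims=True), tf.shape(x))
tf.debugging.assert_near(dx, ref)  # every entry of dx is 4 * 5 = 20.0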