Update gradient of tf.broadcast_to to use tf.reduce_sum instead of tf.unsorted_segment_sum. Add new high rank tests for tf.broadcast_to and its gradient.
PiperOrigin-RevId: 233451123
This commit is contained in:
parent
f35ec21e23
commit
f5593ff762
@ -76,6 +76,26 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
|
||||
v_np = np.broadcast_to(x, output_shape)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBroadcastToShapeLargerDim(self):
|
||||
input_shape = [2, 1, 3, 2, 2, 2]
|
||||
output_shape = [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 15, 3, 2, 2, 2]
|
||||
with self.cached_session(use_gpu=True):
|
||||
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
|
||||
v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
|
||||
v_np = np.broadcast_to(x, output_shape)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBroadcastToShapeLargerDim2(self):
|
||||
input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
|
||||
output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
|
||||
with self.cached_session(use_gpu=True):
|
||||
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
|
||||
v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
|
||||
v_np = np.broadcast_to(x, output_shape)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBroadcastToScalar(self):
|
||||
with self.session(use_gpu=True):
|
||||
@ -88,8 +108,9 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
|
||||
def testBroadcastScalarToNonScalar(self):
|
||||
with self.session(use_gpu=True):
|
||||
x = np.array(1.0, dtype=np.float)
|
||||
v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4])
|
||||
v_np = np.broadcast_to(x, [2, 3, 4])
|
||||
v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4,
|
||||
1, 1, 1])
|
||||
v_np = np.broadcast_to(x, [2, 3, 4, 1, 1, 1])
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -148,6 +169,18 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
|
||||
out, out.get_shape())
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradientWithLargeDim(self):
|
||||
input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
|
||||
output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
|
||||
x = constant_op.constant(np.array(np.random.randn(*input_shape),
|
||||
dtype=np.float32))
|
||||
v = array_ops.broadcast_to(x, output_shape)
|
||||
out = 2 * v
|
||||
with self.cached_session():
|
||||
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
|
||||
out, out.get_shape())
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
@ -839,18 +839,11 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
|
||||
def _BroadcastToGrad(op, grad):
|
||||
input_value = op.inputs[0]
|
||||
broadcast_shape = op.inputs[1]
|
||||
# Assign ids for each position in input_value.
|
||||
input_value_shape = array_ops.shape(input_value)
|
||||
input_value_size = array_ops.size(input_value)
|
||||
ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
|
||||
broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
|
||||
# Group by ids and sum its gradients.
|
||||
grad_flatten = array_ops.reshape(grad, [-1])
|
||||
broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
|
||||
# TODO(apassos): Use reduce_sum for gradient now that we only support
|
||||
# the usual numpy broadcast semantics.
|
||||
updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
|
||||
broadcast_ids_flatten,
|
||||
input_value_size)
|
||||
updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
|
||||
_, reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape,
|
||||
input_value_shape)
|
||||
updates_grad_reshaped = math_ops.reduce_sum(grad,
|
||||
axis=reduction_axes,
|
||||
keepdims=True)
|
||||
updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
|
||||
return [updates_grad, None]
|
||||
|
Loading…
Reference in New Issue
Block a user