Update the gradient of tf.broadcast_to to use tf.reduce_sum instead of tf.unsorted_segment_sum. Add new high-rank tests for tf.broadcast_to and its gradient.
PiperOrigin-RevId: 233451123
commit f5593ff762
parent f35ec21e23
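The core identity behind the change: broadcasting copies each input element into many output positions, so the gradient of broadcast_to is the incoming gradient summed over exactly the broadcast axes. A minimal NumPy sketch of that identity (illustration only, not part of the commit; the shapes are made up):

    import numpy as np

    input_shape = (2, 1, 3)
    output_shape = (4, 2, 5, 3)  # one leading axis added, one unit axis expanded
    grad = np.random.randn(*output_shape).astype(np.float32)

    # Left-pad the input shape with 1s, then find the axes that were broadcast.
    padded = (1,) * (len(output_shape) - len(input_shape)) + input_shape
    reduction_axes = tuple(i for i, (p, o) in enumerate(zip(padded, output_shape))
                           if p == 1 and o > 1)  # here: (0, 2)

    # The adjoint of broadcasting: sum over those axes, then restore the shape.
    input_grad = grad.sum(axis=reduction_axes, keepdims=True).reshape(input_shape)
    assert input_grad.shape == input_shape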
@@ -76,6 +76,26 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
       v_np = np.broadcast_to(x, output_shape)
       self.assertAllEqual(v_tf.eval(), v_np)
 
+  @test_util.run_deprecated_v1
+  def testBroadcastToShapeLargerDim(self):
+    input_shape = [2, 1, 3, 2, 2, 2]
+    output_shape = [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 15, 3, 2, 2, 2]
+    with self.cached_session(use_gpu=True):
+      x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+      v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+      v_np = np.broadcast_to(x, output_shape)
+      self.assertAllEqual(v_tf.eval(), v_np)
+
+  @test_util.run_deprecated_v1
+  def testBroadcastToShapeLargerDim2(self):
+    input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
+    output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
+    with self.cached_session(use_gpu=True):
+      x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+      v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+      v_np = np.broadcast_to(x, output_shape)
+      self.assertAllEqual(v_tf.eval(), v_np)
+
   @test_util.run_deprecated_v1
   def testBroadcastToScalar(self):
     with self.session(use_gpu=True):
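For context, the shapes in these tests follow standard NumPy broadcasting semantics: the input shape is left-padded with 1s to the output rank and unit axes expand, so a rank-6 input can broadcast into a rank-15 output. A quick illustrative check (not from the commit):

    import numpy as np

    x = np.random.randint(5, size=[2, 1, 3, 2, 2, 2])
    y = np.broadcast_to(x, [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 15, 3, 2, 2, 2])
    print(y.shape)  # rank-6 input broadcast into a rank-15 output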
@@ -88,8 +108,9 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
   def testBroadcastScalarToNonScalar(self):
     with self.session(use_gpu=True):
       x = np.array(1.0, dtype=np.float)
-      v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4])
-      v_np = np.broadcast_to(x, [2, 3, 4])
+      v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4,
+                                                                1, 1, 1])
+      v_np = np.broadcast_to(x, [2, 3, 4, 1, 1, 1])
       self.assertAllEqual(v_tf.eval(), v_np)
 
   @test_util.run_deprecated_v1
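The updated test also covers unit axes after the broadcast dimensions; a scalar broadcasts to any target shape, including one with trailing 1s. Illustrative snippet (not from the commit):

    import numpy as np

    v = np.broadcast_to(np.float32(1.0), [2, 3, 4, 1, 1, 1])
    print(v.shape)  # (2, 3, 4, 1, 1, 1)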
@@ -148,6 +169,18 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
                                                   out, out.get_shape())
     self.assertLess(err, 1e-4)
 
+  @test_util.run_deprecated_v1
+  def testGradientWithLargeDim(self):
+    input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
+    output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
+    x = constant_op.constant(np.array(np.random.randn(*input_shape),
+                                      dtype=np.float32))
+    v = array_ops.broadcast_to(x, output_shape)
+    out = 2 * v
+    with self.cached_session():
+      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                    out, out.get_shape())
+    self.assertLess(err, 1e-4)
 
 if __name__ == "__main__":
   test_lib.main()
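The test above drives the new gradient through a high-rank broadcast with the v1 gradient checker. As a rough eager-mode spot check of the same gradient on a small case (a sketch assuming TF 2.x APIs, not the commit's test):

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.randn(2, 1, 3).astype(np.float32))
    with tf.GradientTape() as tape:
      tape.watch(x)
      out = 2 * tf.broadcast_to(x, [4, 2, 5, 3])
    dx = tape.gradient(out, x)  # implicitly the gradient of sum(out)

    # Each input element is copied into 4 * 5 = 20 output positions, scaled by 2.
    print(np.allclose(dx.numpy(), 40.0 * np.ones((2, 1, 3))))  # True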
@@ -839,18 +839,11 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
 def _BroadcastToGrad(op, grad):
   input_value = op.inputs[0]
   broadcast_shape = op.inputs[1]
-  # Assign ids for each position in input_value.
   input_value_shape = array_ops.shape(input_value)
-  input_value_size = array_ops.size(input_value)
-  ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
-  broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
-  # Group by ids and sum its gradients.
-  grad_flatten = array_ops.reshape(grad, [-1])
-  broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
-  # TODO(apassos): Use reduce_sum for gradient now that we only support
-  # the usual numpy broadcast semantics.
-  updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
-                                                       broadcast_ids_flatten,
-                                                       input_value_size)
-  updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+  _, reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape,
+                                                            input_value_shape)
+  updates_grad_reshaped = math_ops.reduce_sum(grad,
+                                              axis=reduction_axes,
+                                              keepdims=True)
+  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
   return [updates_grad, None]
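Why the two formulations agree: the old code labeled every input position, broadcast the labels, and segment-summed the gradient entries by label; the new code sums the gradient directly over the reduction axes reported by broadcast_gradient_args and reshapes back. A side-by-side sketch on a small case (eager TF assumed; illustration only, not part of the commit):

    import numpy as np
    import tensorflow as tf

    input_shape, broadcast_shape = [2, 1, 3], [4, 2, 5, 3]
    grad = tf.constant(np.random.randn(*broadcast_shape).astype(np.float32))

    # Old path: label every input position, broadcast the labels, then group
    # gradient entries by label and sum.
    size = int(np.prod(input_shape))
    ids = tf.reshape(tf.range(size), input_shape)
    broadcast_ids = tf.broadcast_to(ids, broadcast_shape)
    old = tf.reshape(
        tf.math.unsorted_segment_sum(tf.reshape(grad, [-1]),
                                     tf.reshape(broadcast_ids, [-1]), size),
        input_shape)

    # New path: sum over the broadcast axes (0 and 2 here), keep dims, reshape.
    new = tf.reshape(tf.reduce_sum(grad, axis=[0, 2], keepdims=True), input_shape)

    print(np.allclose(old.numpy(), new.numpy()))  # True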