Check that the Bcast result shape is equal to the output_shape in broadcast_to

PiperOrigin-RevId: 242557631
This commit is contained in:
Akshay Modi 2019-04-08 16:05:45 -07:00 committed by TensorFlower Gardener
parent 0eb15926b8
commit c1ff471dec
2 changed files with 14 additions and 0 deletions

View File

@ -80,6 +80,10 @@ class BroadcastToOp : public OpKernel {
errors::InvalidArgument( errors::InvalidArgument(
"Incompatible shapes: ", input_shape.DebugString(), " vs. ", "Incompatible shapes: ", input_shape.DebugString(), " vs. ",
output_shape.DebugString())); output_shape.DebugString()));
OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape,
errors::InvalidArgument("Unable to broadcast tensor of shape ",
input_shape, " to tensor of shape ",
output_shape));
functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape, functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape,
input_tensor, input_shape, bcast); input_tensor, input_shape, bcast);

View File

@ -19,8 +19,10 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradient_checker
@ -127,6 +129,14 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
# check shape inference when shape input is constant # check shape inference when shape input is constant
self.assertAllEqual(shape, v_np.shape) self.assertAllEqual(shape, v_np.shape)
def testBroadcastToBadOutputShape(self):
with context.eager_mode():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Unable to broadcast tensor of shape"):
self.evaluate(
array_ops.broadcast_to(
constant_op.constant([0, 1]), constant_op.constant([2, 1])))
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testGradientForScalar(self): def testGradientForScalar(self):
x = constant_op.constant(1, dtype=dtypes.float32) x = constant_op.constant(1, dtype=dtypes.float32)