diff --git a/tensorflow/core/kernels/broadcast_to_op.cc b/tensorflow/core/kernels/broadcast_to_op.cc index a1e5025ee87..ab9961ced9c 100644 --- a/tensorflow/core/kernels/broadcast_to_op.cc +++ b/tensorflow/core/kernels/broadcast_to_op.cc @@ -80,6 +80,10 @@ class BroadcastToOp : public OpKernel { errors::InvalidArgument( "Incompatible shapes: ", input_shape.DebugString(), " vs. ", output_shape.DebugString())); + OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape, + errors::InvalidArgument("Unable to broadcast tensor of shape ", + input_shape, " to tensor of shape ", + output_shape)); functor::BroadcastTo()(device, ctx, *output_tensor, output_shape, input_tensor, input_shape, bcast); diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 3c5433cb899..f478ee9f643 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -19,8 +19,10 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker @@ -127,6 +129,14 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testBroadcastToBadOutputShape(self): + with context.eager_mode(): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Unable to broadcast tensor of shape"): + self.evaluate( + array_ops.broadcast_to( + constant_op.constant([0, 1]), constant_op.constant([2, 1]))) + @test_util.run_deprecated_v1 def testGradientForScalar(self): x = constant_op.constant(1, dtype=dtypes.float32)